diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_dpmsolver_multistep.py

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -78,7 +78,7 @@ def betas_for_alpha_bar(
 # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
 def rescale_zero_terminal_snr(betas):
     """
-    Rescales betas to have zero terminal SNR Based on https://
+    Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)


     Args:
@@ -230,6 +230,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -330,6 +332,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self,
         num_inference_steps: int = None,
         device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
     ):
         """
@@ -345,6 +348,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                 must be `None`, and `timestep_spacing` attribute will be ignored.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         if num_inference_steps is None and timesteps is None:
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
@@ -366,7 +372,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
             last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

-            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://
+            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
             if self.config.timestep_spacing == "linspace":
                 timesteps = (
                     np.linspace(0, last_timestep - 1, num_inference_steps + 1)
@@ -460,7 +466,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         pixels from saturation at each step. We find that dynamic thresholding results in significantly better
         photorealism as well as better image-text alignment, especially when using very large guidance weights."

-        https://
+        https://huggingface.co/papers/2205.11487
         """
         dtype = sample.dtype
         batch_size, channels, *remaining_dims = sample.shape
@@ -646,7 +652,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 1:
                 sample = args[1]
             else:
-                raise ValueError("missing `sample` as a required
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep is not None:
             deprecate(
                 "timesteps",
@@ -741,7 +747,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 2:
                 sample = args[2]
             else:
-                raise ValueError("
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep is not None:
             deprecate(
                 "timesteps",
@@ -810,7 +816,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 2:
                 sample = args[2]
             else:
-                raise ValueError("
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep_list is not None:
             deprecate(
                 "timestep_list",
@@ -845,7 +851,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         r0 = h_0 / h
         D0, D1 = m0, (1.0 / r0) * (m0 - m1)
         if self.config.algorithm_type == "dpmsolver++":
-            # See https://
+            # See https://huggingface.co/papers/2211.01095 for detailed derivations
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (sigma_t / sigma_s0) * sample
@@ -859,7 +865,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                     + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                 )
         elif self.config.algorithm_type == "dpmsolver":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (alpha_t / alpha_s0) * sample
@@ -934,7 +940,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 2:
                 sample = args[2]
             else:
-                raise ValueError("
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep_list is not None:
             deprecate(
                 "timestep_list",
@@ -975,7 +981,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
         D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
         if self.config.algorithm_type == "dpmsolver++":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             x_t = (
                 (sigma_t / sigma_s0) * sample
                 - (alpha_t * (torch.exp(-h) - 1.0)) * D0
@@ -983,7 +989,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
             )
         elif self.config.algorithm_type == "dpmsolver":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             x_t = (
                 (alpha_t / alpha_s0) * sample
                 - (sigma_t * (torch.exp(h) - 1.0)) * D0
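The functional change in this file, beyond the restored paper links, is the dynamic-shifting plumbing: two new config fields, `use_dynamic_shifting` and `time_shift_type`, plus a `mu` argument on `set_timesteps` that sets `config.flow_shift = exp(mu)`. Below is a minimal, hedged sketch of how that path could be exercised directly; the `use_flow_sigmas=True` flag and the concrete `mu` value are illustrative assumptions, since flow-matching pipelines normally compute `mu` from the latent sequence length and pass it for you.

```python
# Hedged sketch, not taken from the diff: exercises the new `use_dynamic_shifting` /
# `time_shift_type` config fields and the `mu` kwarg added to set_timesteps above.
import math

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    use_flow_sigmas=True,           # assumption: flow_shift only matters for flow-matching sigmas
    use_dynamic_shifting=True,      # new config field from this diff
    time_shift_type="exponential",  # new config field from this diff
)

mu = 0.8  # hypothetical value; pipelines usually derive mu from the image/latent sequence length
scheduler.set_timesteps(num_inference_steps=25, mu=mu)  # per the new code path, sets config.flow_shift = exp(mu)
print(math.isclose(scheduler.config.flow_shift, math.exp(mu)))
```

Note that passing `mu` without opting in via `use_dynamic_shifting=True` trips the new assert in `set_timesteps`, so callers are expected to enable the config flag explicitly.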
diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -80,14 +80,15 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
     the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
     samples, and it can generate quite good samples even in only 10 steps.

-    For more details, see the original paper: https://
+    For more details, see the original paper: https://huggingface.co/papers/2206.00927 and
+    https://huggingface.co/papers/2211.01095

     Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
     recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.

-    We also support the "dynamic thresholding" method in Imagen (https://
-    diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the
-    thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+    We also support the "dynamic thresholding" method in Imagen (https://huggingface.co/papers/2205.11487). For
+    pixel-space diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the
+    dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
     stable-diffusion).

     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
@@ -95,7 +96,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
     [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
     [`~SchedulerMixin.from_pretrained`] functions.

-    For more details, see the original paper: https://
+    For more details, see the original paper: https://huggingface.co/papers/2206.00927 and
+    https://huggingface.co/papers/2211.01095

     Args:
         num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -113,21 +115,21 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
             indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
             or `v-prediction`.
         thresholding (`bool`, default `False`):
-            whether to use the "dynamic thresholding" method (introduced by Imagen,
-            For pixel-space diffusion models, you can set both
-            use the dynamic thresholding. Note that the
-            models (such as stable-diffusion).
+            whether to use the "dynamic thresholding" method (introduced by Imagen,
+            https://huggingface.co/papers/2205.11487). For pixel-space diffusion models, you can set both
+            `algorithm_type=dpmsolver++` and `thresholding=True` to use the dynamic thresholding. Note that the
+            thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion).
         dynamic_thresholding_ratio (`float`, default `0.995`):
             the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
-            (https://
+            (https://huggingface.co/papers/2205.11487).
         sample_max_value (`float`, default `1.0`):
             the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++`.
         algorithm_type (`str`, default `dpmsolver++`):
             the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
-            algorithms in https://
-            https://
-            sampling (e.g. stable-diffusion).
+            algorithms in https://huggingface.co/papers/2206.00927, and the `dpmsolver++` type implements the
+            algorithms in https://huggingface.co/papers/2211.01095. We recommend to use `dpmsolver++` with
+            `solver_order=2` for guided sampling (e.g. stable-diffusion).
         solver_type (`str`, default `midpoint`):
             the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
             the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
@@ -297,7 +299,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
             )

         if self.config.thresholding:
-            # Dynamic thresholding in https://
+            # Dynamic thresholding in https://huggingface.co/papers/2205.11487
             dynamic_max_val = jnp.percentile(
                 jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
             )
@@ -335,7 +337,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
         """
         One step for the first-order DPM-Solver (equivalent to DDIM).

-        See https://
+        See https://huggingface.co/papers/2206.00927 for the detailed derivation.

         Args:
             model_output (`jnp.ndarray`): direct output from learned diffusion model.
@@ -390,7 +392,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
         r0 = h_0 / h
         D0, D1 = m0, (1.0 / r0) * (m0 - m1)
         if self.config.algorithm_type == "dpmsolver++":
-            # See https://
+            # See https://huggingface.co/papers/2211.01095 for detailed derivations
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (sigma_t / sigma_s0) * sample
@@ -404,7 +406,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
                     + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
                 )
         elif self.config.algorithm_type == "dpmsolver":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (alpha_t / alpha_s0) * sample
@@ -458,7 +460,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
         D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
         D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
         if self.config.algorithm_type == "dpmsolver++":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             x_t = (
                 (sigma_t / sigma_s0) * sample
                 - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
@@ -466,7 +468,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
                 - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
             )
         elif self.config.algorithm_type == "dpmsolver":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             x_t = (
                 (alpha_t / alpha_s0) * sample
                 - (sigma_t * (jnp.exp(h) - 1.0)) * D0
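Several of the docstring fixes above spell out the Imagen dynamic-thresholding description that was previously truncated. For readers who want the mechanics rather than the citation, here is a small, hedged NumPy sketch of the idea (clip `x0` predictions to a per-sample percentile `s >= 1`, then rescale); it mirrors the `jnp.percentile` call visible in the hunk but is not the library's exact implementation.

```python
# Hedged NumPy sketch of Imagen-style dynamic thresholding, for orientation only.
import numpy as np

def dynamic_threshold(x0_pred: np.ndarray, ratio: float = 0.995, max_value: float = 1.0) -> np.ndarray:
    # Per-sample percentile s of |x0|; never threshold below max_value.
    flat = np.abs(x0_pred).reshape(x0_pred.shape[0], -1)
    s = np.percentile(flat, ratio * 100.0, axis=1)
    s = np.maximum(s, max_value).reshape(-1, *([1] * (x0_pred.ndim - 1)))
    # Clip to [-s, s], then divide by s to bring values back to roughly [-1, 1].
    return np.clip(x0_pred, -s, s) / s

x0 = 3.0 * np.random.randn(2, 3, 8, 8)
print(np.abs(dynamic_threshold(x0)).max() <= 1.0)  # True
```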
diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -257,7 +257,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped).item()
         self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx

-        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://
+        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
         if self.config.timestep_spacing == "linspace":
             timesteps = (
                 np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64)
@@ -338,7 +338,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         pixels from saturation at each step. We find that dynamic thresholding results in significantly better
         photorealism as well as better image-text alignment, especially when using very large guidance weights."

-        https://
+        https://huggingface.co/papers/2205.11487
         """
         dtype = sample.dtype
         batch_size, channels, *remaining_dims = sample.shape
@@ -513,7 +513,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 1:
                 sample = args[1]
             else:
-                raise ValueError("missing `sample` as a required
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep is not None:
             deprecate(
                 "timesteps",
@@ -609,7 +609,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 2:
                 sample = args[2]
             else:
-                raise ValueError("
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep is not None:
             deprecate(
                 "timesteps",
@@ -679,7 +679,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 2:
                 sample = args[2]
             else:
-                raise ValueError("
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep_list is not None:
             deprecate(
                 "timestep_list",
@@ -714,7 +714,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         r0 = h_0 / h
         D0, D1 = m0, (1.0 / r0) * (m0 - m1)
         if self.config.algorithm_type == "dpmsolver++":
-            # See https://
+            # See https://huggingface.co/papers/2211.01095 for detailed derivations
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (sigma_t / sigma_s0) * sample
@@ -728,7 +728,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
                     + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                 )
         elif self.config.algorithm_type == "dpmsolver":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (alpha_t / alpha_s0) * sample
@@ -804,7 +804,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 2:
                 sample = args[2]
             else:
-                raise ValueError("
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep_list is not None:
             deprecate(
                 "timestep_list",
@@ -845,7 +845,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
         D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
         if self.config.algorithm_type == "dpmsolver++":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             x_t = (
                 (sigma_t / sigma_s0) * sample
                 - (alpha_t * (torch.exp(-h) - 1.0)) * D0
@@ -853,7 +853,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
                 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
             )
         elif self.config.algorithm_type == "dpmsolver":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             x_t = (
                 (alpha_t / alpha_s0) * sample
                 - (sigma_t * (torch.exp(h) - 1.0)) * D0
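The restored `# See https://huggingface.co/papers/2211.01095 ...` comments refer to the multistep second-order update that these hunks only show in fragments. As a reading aid, reconstructed from the visible code and the cited DPM-Solver / DPM-Solver++ papers rather than quoted from the diff, the `dpmsolver++` branch with the `heun` solver type computes:

```latex
% Reconstructed from the code fragments above; notation follows DPM-Solver++ (2211.01095).
x_t = \frac{\sigma_t}{\sigma_{s_0}}\, x_{s_0}
      - \alpha_t \left(e^{-h} - 1\right) D_0
      + \alpha_t \left(\frac{e^{-h} - 1}{h} + 1\right) D_1,
\qquad
h = \lambda_t - \lambda_{s_0}, \quad
r_0 = \frac{h_0}{h}, \quad
D_0 = m_0, \quad
D_1 = \frac{m_0 - m_1}{r_0},
```

where m_0 and m_1 are the two most recent converted model outputs; the `midpoint` branch replaces the last coefficient with -\tfrac{1}{2}\,\alpha_t(e^{-h}-1), and the plain `dpmsolver` branch uses the \alpha_t/\alpha_{s_0} and \sigma_t(e^{h}-1) form visible in the hunks.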
diffusers/schedulers/scheduling_dpmsolver_sde.py

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -352,7 +352,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):

         num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

-        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://
+        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
         if self.config.timestep_spacing == "linspace":
             timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
         elif self.config.timestep_spacing == "leading":
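The comment restored in this hunk (and in the multistep schedulers above) points at the three timestep-spacing strategies from Table 2 of the cited paper. A schematic sketch is below: the "linspace" line mirrors the one visible in the `DPMSolverSDEScheduler` hunk, while the "leading" and "trailing" forms are assumptions written from the paper, not copied from the schedulers, whose exact rounding and offset handling differ per class.

```python
# Hedged sketch of the "linspace" / "leading" / "trailing" spacing strategies
# (Table 2 of https://huggingface.co/papers/2305.08891). Illustration only.
import numpy as np

T, n = 1000, 10  # training timesteps, inference steps

linspace = np.linspace(0, T - 1, n, dtype=float)[::-1].copy()          # as in the hunk above
leading = (np.arange(0, n) * (T // n))[::-1].astype(float)             # assumed form
trailing = np.arange(T, 0, -T / n).round() - 1                         # assumed form

print(linspace[:3], leading[:3], trailing[:3])
```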
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -169,6 +169,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
169
169
|
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
170
170
|
lambda_min_clipped: float = -float("inf"),
|
171
171
|
variance_type: Optional[str] = None,
|
172
|
+
use_dynamic_shifting: bool = False,
|
173
|
+
time_shift_type: str = "exponential",
|
172
174
|
):
|
173
175
|
if self.config.use_beta_sigmas and not is_scipy_available():
|
174
176
|
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
@@ -218,7 +220,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
 
         if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
             raise ValueError(
-                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please
+                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
             )
 
         # setable values
@@ -301,6 +303,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         self,
         num_inference_steps: int = None,
         device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
     ):
         """
@@ -316,6 +319,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
             passed, `num_inference_steps` must be `None`.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         if num_inference_steps is None and timesteps is None:
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
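The added `mu` branch converts a resolution-dependent shift parameter into `flow_shift = exp(mu)` and asserts that dynamic shifting with the exponential time-shift type was enabled at construction. The sketch below shows how such a shift is commonly applied to a sigma schedule; the remapping formula is the standard flow-match shift used elsewhere in diffusers and is an assumption about downstream use, not copied from this hunk:

```python
import numpy as np

def apply_exponential_shift(sigmas: np.ndarray, mu: float) -> np.ndarray:
    # flow_shift = exp(mu), as in the hunk above; the sigma remapping below is
    # the common shifted form and is an illustrative assumption.
    shift = np.exp(mu)
    return shift * sigmas / (1 + (shift - 1) * sigmas)

# Example: sigmas in (0, 1]; a larger mu pushes more steps toward high noise.
sigmas = np.linspace(1.0, 1e-3, 10)
shifted = apply_exponential_shift(sigmas, mu=1.15)
```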
@@ -410,7 +416,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         pixels from saturation at each step. We find that dynamic thresholding results in significantly better
         photorealism as well as better image-text alignment, especially when using very large guidance weights."
 
-        https://
+        https://huggingface.co/papers/2205.11487
         """
         dtype = sample.dtype
         batch_size, channels, *remaining_dims = sample.shape
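The docstring above summarizes Imagen-style dynamic thresholding. A minimal sketch of the idea it describes (per-sample percentile clamp followed by rescaling), written independently of the scheduler's own `_threshold_sample` implementation:

```python
import torch

def dynamic_threshold(x0: torch.Tensor, ratio: float = 0.995, max_value: float = 1.0) -> torch.Tensor:
    """Illustrative sketch: clamp predicted x0 to its per-sample percentile s, then rescale by s."""
    b = x0.shape[0]
    flat = x0.reshape(b, -1).abs().float()
    s = torch.quantile(flat, ratio, dim=1)          # per-sample percentile of |pixels|
    s = s.clamp(min=max_value).to(x0.dtype)          # never shrink the usual [-1, 1] range
    s = s.view(b, *([1] * (x0.ndim - 1)))            # broadcast over remaining dims
    return x0.clamp(-s, s) / s                       # threshold to [-s, s], divide by s
```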
@@ -584,7 +590,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 1:
                 sample = args[1]
             else:
-                raise ValueError("missing `sample` as a required
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep is not None:
             deprecate(
                 "timesteps",
@@ -681,7 +687,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 2:
                 sample = args[2]
             else:
-                raise ValueError("
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep is not None:
             deprecate(
                 "timesteps",
@@ -746,7 +752,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 2:
                 sample = args[2]
             else:
-                raise ValueError("
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep_list is not None:
             deprecate(
                 "timestep_list",
@@ -780,7 +786,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         r0 = h_0 / h
         D0, D1 = m1, (1.0 / r0) * (m0 - m1)
         if self.config.algorithm_type == "dpmsolver++":
-            # See https://
+            # See https://huggingface.co/papers/2211.01095 for detailed derivations
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (sigma_t / sigma_s1) * sample
@@ -794,7 +800,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                     + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                 )
         elif self.config.algorithm_type == "dpmsolver":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (alpha_t / alpha_s1) * sample
@@ -858,7 +864,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 2:
                 sample = args[2]
             else:
-                raise ValueError("
+                raise ValueError("missing `sample` as a required keyword argument")
         if timestep_list is not None:
             deprecate(
                 "timestep_list",
@@ -899,7 +905,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1)
         D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1)
         if self.config.algorithm_type == "dpmsolver++":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (sigma_t / sigma_s2) * sample
@@ -914,7 +920,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                     - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
                 )
         elif self.config.algorithm_type == "dpmsolver":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (alpha_t / alpha_s2) * sample
@@ -981,12 +987,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             if len(args) > 2:
                 sample = args[2]
             else:
-                raise ValueError("
+                raise ValueError("missing `sample` as a required keyword argument")
         if order is None:
             if len(args) > 3:
                 order = args[3]
             else:
-                raise ValueError("
+                raise ValueError("missing `order` as a required keyword argument")
         if timestep_list is not None:
             deprecate(
                 "timestep_list",
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     `EDMDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
 
     [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
-    https://
+    https://huggingface.co/papers/2206.00364
 
     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
     methods the library implements for all schedulers such as loading and saving.
@@ -47,8 +47,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
         sigma_schedule (`str`, *optional*, defaults to `karras`):
             Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
-            (https://
-            incorporated in this model: https://huggingface.co/stabilityai/cosxl.
+            (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
+            schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
         num_train_timesteps (`int`, defaults to 1000):
             The number of diffusion steps to train the model.
         solver_order (`int`, defaults to 2):
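The `sigma_schedule` docstring fixed above contrasts the Karras schedule from the cited EDM paper with an exponential schedule. A minimal sketch of both, assuming the usual `rho = 7` parameterization (illustrative, not copied from the scheduler):

```python
import numpy as np

def karras_sigmas(sigma_min: float, sigma_max: float, n: int, rho: float = 7.0) -> np.ndarray:
    # EDM-style rho-spaced schedule: interpolate linearly in sigma**(1/rho) space.
    ramp = np.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

def exponential_sigmas(sigma_min: float, sigma_max: float, n: int) -> np.ndarray:
    # Log-linear (geometric) spacing between sigma_max and sigma_min.
    return np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), n))
```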
@@ -176,7 +176,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
     def precondition_inputs(self, sample, sigma):
-        c_in =
+        c_in = self._get_conditioning_c_in(sigma)
         scaled_sample = sample * c_in
         return scaled_sample
 
@@ -305,7 +305,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         pixels from saturation at each step. We find that dynamic thresholding results in significantly better
         photorealism as well as better image-text alignment, especially when using very large guidance weights."
 
-        https://
+        https://huggingface.co/papers/2205.11487
         """
         dtype = sample.dtype
         batch_size, channels, *remaining_dims = sample.shape
@@ -472,7 +472,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         r0 = h_0 / h
         D0, D1 = m0, (1.0 / r0) * (m0 - m1)
         if self.config.algorithm_type == "dpmsolver++":
-            # See https://
+            # See https://huggingface.co/papers/2211.01095 for detailed derivations
             if self.config.solver_type == "midpoint":
                 x_t = (
                     (sigma_t / sigma_s0) * sample
@@ -548,7 +548,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
         D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
         if self.config.algorithm_type == "dpmsolver++":
-            # See https://
+            # See https://huggingface.co/papers/2206.00927 for detailed derivations
             x_t = (
                 (sigma_t / sigma_s0) * sample
                 - (alpha_t * (torch.exp(-h) - 1.0)) * D0
@@ -703,5 +703,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
+    # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
+    def _get_conditioning_c_in(self, sigma):
+        c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+        return c_in
+
     def __len__(self):
         return self.config.num_train_timesteps
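The new `_get_conditioning_c_in` helper factors out the EDM input-preconditioning coefficient, c_in(sigma) = 1 / sqrt(sigma**2 + sigma_data**2). A standalone sketch of how `precondition_inputs` uses it, with sigma_data = 0.5 (the scheduler default noted earlier in this diff) as the assumed value:

```python
import torch

def precondition_inputs(sample: torch.Tensor, sigma: torch.Tensor, sigma_data: float = 0.5) -> torch.Tensor:
    # c_in = 1 / sqrt(sigma**2 + sigma_data**2), matching the helper added above.
    c_in = 1 / ((sigma**2 + sigma_data**2) ** 0.5)
    return sample * c_in
```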