diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,7 @@
|
|
14
14
|
|
15
15
|
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
|
16
16
|
|
17
|
+
import math
|
17
18
|
from typing import List, Optional, Tuple, Union
|
18
19
|
|
19
20
|
import numpy as np
|
@@ -44,6 +45,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
44
45
|
range is [0.2, 80.0].
|
45
46
|
sigma_data (`float`, *optional*, defaults to 0.5):
|
46
47
|
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
|
48
|
+
sigma_schedule (`str`, *optional*, defaults to `karras`):
|
49
|
+
Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
|
50
|
+
(https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
|
51
|
+
incorporated in this model: https://huggingface.co/stabilityai/cosxl.
|
47
52
|
num_train_timesteps (`int`, defaults to 1000):
|
48
53
|
The number of diffusion steps to train the model.
|
49
54
|
solver_order (`int`, defaults to 2):
|
@@ -62,10 +67,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
62
67
|
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
63
68
|
`algorithm_type="dpmsolver++"`.
|
64
69
|
algorithm_type (`str`, defaults to `dpmsolver++`):
|
65
|
-
Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The
|
66
|
-
|
67
|
-
|
68
|
-
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
70
|
+
Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver++` type implements
|
71
|
+
the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to
|
72
|
+
use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
69
73
|
solver_type (`str`, defaults to `midpoint`):
|
70
74
|
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
71
75
|
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
@@ -77,8 +81,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
77
81
|
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
|
78
82
|
steps, but sometimes may result in blurring.
|
79
83
|
final_sigmas_type (`str`, defaults to `"zero"`):
|
80
|
-
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
|
81
|
-
is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
84
|
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
85
|
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
82
86
|
"""
|
83
87
|
|
84
88
|
_compatibles = []
|
@@ -90,6 +94,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
90
94
|
sigma_min: float = 0.002,
|
91
95
|
sigma_max: float = 80.0,
|
92
96
|
sigma_data: float = 0.5,
|
97
|
+
sigma_schedule: str = "karras",
|
93
98
|
num_train_timesteps: int = 1000,
|
94
99
|
prediction_type: str = "epsilon",
|
95
100
|
rho: float = 7.0,
|
@@ -114,7 +119,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
114
119
|
if solver_type in ["logrho", "bh1", "bh2"]:
|
115
120
|
self.register_to_config(solver_type="midpoint")
|
116
121
|
else:
|
117
|
-
raise NotImplementedError(f"{solver_type}
|
122
|
+
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
118
123
|
|
119
124
|
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
|
120
125
|
raise ValueError(
|
@@ -122,7 +127,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
122
127
|
)
|
123
128
|
|
124
129
|
ramp = torch.linspace(0, 1, num_train_timesteps)
|
125
|
-
|
130
|
+
if sigma_schedule == "karras":
|
131
|
+
sigmas = self._compute_karras_sigmas(ramp)
|
132
|
+
elif sigma_schedule == "exponential":
|
133
|
+
sigmas = self._compute_exponential_sigmas(ramp)
|
134
|
+
|
126
135
|
self.timesteps = self.precondition_noise(sigmas)
|
127
136
|
|
128
137
|
self.sigmas = self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
@@ -143,7 +152,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
143
152
|
@property
|
144
153
|
def step_index(self):
|
145
154
|
"""
|
146
|
-
The index counter for current timestep. It will
|
155
|
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
147
156
|
"""
|
148
157
|
return self._step_index
|
149
158
|
|
@@ -197,21 +206,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
197
206
|
return denoised
|
198
207
|
|
199
208
|
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
|
200
|
-
def scale_model_input(
|
201
|
-
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
202
|
-
) -> torch.FloatTensor:
|
209
|
+
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
203
210
|
"""
|
204
211
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
205
212
|
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
206
213
|
|
207
214
|
Args:
|
208
|
-
sample (`torch.FloatTensor`):
|
215
|
+
sample (`torch.Tensor`):
|
209
216
|
The input sample.
|
210
217
|
timestep (`int`, *optional*):
|
211
218
|
The current timestep in the diffusion chain.
|
212
219
|
|
213
220
|
Returns:
|
214
|
-
`torch.FloatTensor`:
|
221
|
+
`torch.Tensor`:
|
215
222
|
A scaled input sample.
|
216
223
|
"""
|
217
224
|
if self.step_index is None:
|
@@ -237,7 +244,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
237
244
|
self.num_inference_steps = num_inference_steps
|
238
245
|
|
239
246
|
ramp = np.linspace(0, 1, self.num_inference_steps)
|
240
|
-
|
247
|
+
if self.config.sigma_schedule == "karras":
|
248
|
+
sigmas = self._compute_karras_sigmas(ramp)
|
249
|
+
elif self.config.sigma_schedule == "exponential":
|
250
|
+
sigmas = self._compute_exponential_sigmas(ramp)
|
241
251
|
|
242
252
|
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
243
253
|
self.timesteps = self.precondition_noise(sigmas)
|
@@ -263,10 +273,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
263
273
|
self._begin_index = None
|
264
274
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
265
275
|
|
266
|
-
#
|
267
|
-
def
|
276
|
+
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
|
277
|
+
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
|
268
278
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
269
|
-
|
270
279
|
sigma_min = sigma_min or self.config.sigma_min
|
271
280
|
sigma_max = sigma_max or self.config.sigma_max
|
272
281
|
|
@@ -274,10 +283,22 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
274
283
|
min_inv_rho = sigma_min ** (1 / rho)
|
275
284
|
max_inv_rho = sigma_max ** (1 / rho)
|
276
285
|
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
286
|
+
|
287
|
+
return sigmas
|
288
|
+
|
289
|
+
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
|
290
|
+
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
|
291
|
+
"""Implementation closely follows k-diffusion.
|
292
|
+
|
293
|
+
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
|
294
|
+
"""
|
295
|
+
sigma_min = sigma_min or self.config.sigma_min
|
296
|
+
sigma_max = sigma_max or self.config.sigma_max
|
297
|
+
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
|
277
298
|
return sigmas
|
278
299
|
|
279
300
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
280
|
-
def _threshold_sample(self, sample: torch.
|
301
|
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
281
302
|
"""
|
282
303
|
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
283
304
|
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
@@ -342,9 +363,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
342
363
|
|
343
364
|
def convert_model_output(
|
344
365
|
self,
|
345
|
-
model_output: torch.
|
346
|
-
sample: torch.
|
347
|
-
) -> torch.
|
366
|
+
model_output: torch.Tensor,
|
367
|
+
sample: torch.Tensor = None,
|
368
|
+
) -> torch.Tensor:
|
348
369
|
"""
|
349
370
|
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
350
371
|
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
|
@@ -358,13 +379,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
358
379
|
</Tip>
|
359
380
|
|
360
381
|
Args:
|
361
|
-
model_output (`torch.FloatTensor`):
|
382
|
+
model_output (`torch.Tensor`):
|
362
383
|
The direct output from the learned diffusion model.
|
363
|
-
sample (`torch.FloatTensor`):
|
384
|
+
sample (`torch.Tensor`):
|
364
385
|
A current instance of a sample created by the diffusion process.
|
365
386
|
|
366
387
|
Returns:
|
367
|
-
`torch.FloatTensor`:
|
388
|
+
`torch.Tensor`:
|
368
389
|
The converted model output.
|
369
390
|
"""
|
370
391
|
sigma = self.sigmas[self.step_index]
|
@@ -377,21 +398,21 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
377
398
|
|
378
399
|
def dpm_solver_first_order_update(
|
379
400
|
self,
|
380
|
-
model_output: torch.
|
381
|
-
sample: torch.
|
382
|
-
noise: Optional[torch.
|
383
|
-
) -> torch.
|
401
|
+
model_output: torch.Tensor,
|
402
|
+
sample: torch.Tensor = None,
|
403
|
+
noise: Optional[torch.Tensor] = None,
|
404
|
+
) -> torch.Tensor:
|
384
405
|
"""
|
385
406
|
One step for the first-order DPMSolver (equivalent to DDIM).
|
386
407
|
|
387
408
|
Args:
|
388
|
-
model_output (`torch.FloatTensor`):
|
409
|
+
model_output (`torch.Tensor`):
|
389
410
|
The direct output from the learned diffusion model.
|
390
|
-
sample (`torch.FloatTensor`):
|
411
|
+
sample (`torch.Tensor`):
|
391
412
|
A current instance of a sample created by the diffusion process.
|
392
413
|
|
393
414
|
Returns:
|
394
|
-
`torch.FloatTensor`:
|
415
|
+
`torch.Tensor`:
|
395
416
|
The sample tensor at the previous timestep.
|
396
417
|
"""
|
397
418
|
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
@@ -415,21 +436,21 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
415
436
|
|
416
437
|
def multistep_dpm_solver_second_order_update(
|
417
438
|
self,
|
418
|
-
model_output_list: List[torch.
|
419
|
-
sample: torch.
|
420
|
-
noise: Optional[torch.
|
421
|
-
) -> torch.
|
439
|
+
model_output_list: List[torch.Tensor],
|
440
|
+
sample: torch.Tensor = None,
|
441
|
+
noise: Optional[torch.Tensor] = None,
|
442
|
+
) -> torch.Tensor:
|
422
443
|
"""
|
423
444
|
One step for the second-order multistep DPMSolver.
|
424
445
|
|
425
446
|
Args:
|
426
|
-
model_output_list (`List[torch.FloatTensor]`):
|
447
|
+
model_output_list (`List[torch.Tensor]`):
|
427
448
|
The direct outputs from learned diffusion model at current and latter timesteps.
|
428
|
-
sample (`torch.FloatTensor`):
|
449
|
+
sample (`torch.Tensor`):
|
429
450
|
A current instance of a sample created by the diffusion process.
|
430
451
|
|
431
452
|
Returns:
|
432
|
-
`torch.FloatTensor`:
|
453
|
+
`torch.Tensor`:
|
433
454
|
The sample tensor at the previous timestep.
|
434
455
|
"""
|
435
456
|
sigma_t, sigma_s0, sigma_s1 = (
|
@@ -486,20 +507,20 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
486
507
|
|
487
508
|
def multistep_dpm_solver_third_order_update(
|
488
509
|
self,
|
489
|
-
model_output_list: List[torch.
|
490
|
-
sample: torch.
|
491
|
-
) -> torch.
|
510
|
+
model_output_list: List[torch.Tensor],
|
511
|
+
sample: torch.Tensor = None,
|
512
|
+
) -> torch.Tensor:
|
492
513
|
"""
|
493
514
|
One step for the third-order multistep DPMSolver.
|
494
515
|
|
495
516
|
Args:
|
496
|
-
model_output_list (`List[torch.FloatTensor]`):
|
517
|
+
model_output_list (`List[torch.Tensor]`):
|
497
518
|
The direct outputs from learned diffusion model at current and latter timesteps.
|
498
|
-
sample (`torch.FloatTensor`):
|
519
|
+
sample (`torch.Tensor`):
|
499
520
|
A current instance of a sample created by diffusion process.
|
500
521
|
|
501
522
|
Returns:
|
502
|
-
`torch.FloatTensor`:
|
523
|
+
`torch.Tensor`:
|
503
524
|
The sample tensor at the previous timestep.
|
504
525
|
"""
|
505
526
|
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
@@ -573,9 +594,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
573
594
|
|
574
595
|
def step(
|
575
596
|
self,
|
576
|
-
model_output: torch.
|
597
|
+
model_output: torch.Tensor,
|
577
598
|
timestep: int,
|
578
|
-
sample: torch.
|
599
|
+
sample: torch.Tensor,
|
579
600
|
generator=None,
|
580
601
|
return_dict: bool = True,
|
581
602
|
) -> Union[SchedulerOutput, Tuple]:
|
@@ -584,11 +605,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
584
605
|
the multistep DPMSolver.
|
585
606
|
|
586
607
|
Args:
|
587
|
-
model_output (`torch.FloatTensor`):
|
608
|
+
model_output (`torch.Tensor`):
|
588
609
|
The direct output from learned diffusion model.
|
589
610
|
timestep (`int`):
|
590
611
|
The current discrete timestep in the diffusion chain.
|
591
|
-
sample (`torch.FloatTensor`):
|
612
|
+
sample (`torch.Tensor`):
|
592
613
|
A current instance of a sample created by the diffusion process.
|
593
614
|
generator (`torch.Generator`, *optional*):
|
594
615
|
A random number generator.
|
@@ -652,10 +673,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
652
673
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
653
674
|
def add_noise(
|
654
675
|
self,
|
655
|
-
original_samples: torch.FloatTensor,
|
656
|
-
noise: torch.FloatTensor,
|
657
|
-
timesteps: torch.FloatTensor,
|
658
|
-
) -> torch.FloatTensor:
|
676
|
+
original_samples: torch.Tensor,
|
677
|
+
noise: torch.Tensor,
|
678
|
+
timesteps: torch.Tensor,
|
679
|
+
) -> torch.Tensor:
|
659
680
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
660
681
|
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
661
682
|
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
@@ -669,7 +690,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
669
690
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
670
691
|
if self.begin_index is None:
|
671
692
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
693
|
+
elif self.step_index is not None:
|
694
|
+
# add_noise is called after first denoising step (for inpainting)
|
695
|
+
step_indices = [self.step_index] * timesteps.shape[0]
|
672
696
|
else:
|
697
|
+
# add noise is called before first denoising step to create initial latent(img2img)
|
673
698
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
674
699
|
|
675
700
|
sigma = sigmas[step_indices].flatten()
|
@@ -12,6 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
import math
|
15
16
|
from dataclasses import dataclass
|
16
17
|
from typing import Optional, Tuple, Union
|
17
18
|
|
@@ -34,16 +35,16 @@ class EDMEulerSchedulerOutput(BaseOutput):
|
|
34
35
|
Output class for the scheduler's `step` function output.
|
35
36
|
|
36
37
|
Args:
|
37
|
-
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
38
|
+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
38
39
|
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
39
40
|
denoising loop.
|
40
|
-
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41
|
+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41
42
|
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
42
43
|
`pred_original_sample` can be used to preview progress or for guidance.
|
43
44
|
"""
|
44
45
|
|
45
|
-
prev_sample: torch.FloatTensor
|
46
|
-
pred_original_sample: Optional[torch.FloatTensor] = None
|
46
|
+
prev_sample: torch.Tensor
|
47
|
+
pred_original_sample: Optional[torch.Tensor] = None
|
47
48
|
|
48
49
|
|
49
50
|
class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
@@ -65,6 +66,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
65
66
|
range is [0.2, 80.0].
|
66
67
|
sigma_data (`float`, *optional*, defaults to 0.5):
|
67
68
|
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
|
69
|
+
sigma_schedule (`str`, *optional*, defaults to `karras`):
|
70
|
+
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
|
71
|
+
(https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
|
72
|
+
incorporated in this model: https://huggingface.co/stabilityai/cosxl.
|
68
73
|
num_train_timesteps (`int`, defaults to 1000):
|
69
74
|
The number of diffusion steps to train the model.
|
70
75
|
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
@@ -84,15 +89,23 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
84
89
|
sigma_min: float = 0.002,
|
85
90
|
sigma_max: float = 80.0,
|
86
91
|
sigma_data: float = 0.5,
|
92
|
+
sigma_schedule: str = "karras",
|
87
93
|
num_train_timesteps: int = 1000,
|
88
94
|
prediction_type: str = "epsilon",
|
89
95
|
rho: float = 7.0,
|
90
96
|
):
|
97
|
+
if sigma_schedule not in ["karras", "exponential"]:
|
98
|
+
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
|
99
|
+
|
91
100
|
# setable values
|
92
101
|
self.num_inference_steps = None
|
93
102
|
|
94
103
|
ramp = torch.linspace(0, 1, num_train_timesteps)
|
95
|
-
sigmas = self._compute_sigmas(ramp)
|
104
|
+
if sigma_schedule == "karras":
|
105
|
+
sigmas = self._compute_karras_sigmas(ramp)
|
106
|
+
elif sigma_schedule == "exponential":
|
107
|
+
sigmas = self._compute_exponential_sigmas(ramp)
|
108
|
+
|
96
109
|
self.timesteps = self.precondition_noise(sigmas)
|
97
110
|
|
98
111
|
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
@@ -111,7 +124,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
111
124
|
@property
|
112
125
|
def step_index(self):
|
113
126
|
"""
|
114
|
-
The index counter for current timestep. It will increae 1 after each scheduler step.
|
127
|
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
115
128
|
"""
|
116
129
|
return self._step_index
|
117
130
|
|
@@ -161,21 +174,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
161
174
|
|
162
175
|
return denoised
|
163
176
|
|
164
|
-
def scale_model_input(
|
165
|
-
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
166
|
-
) -> torch.FloatTensor:
|
177
|
+
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
167
178
|
"""
|
168
179
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
169
180
|
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
170
181
|
|
171
182
|
Args:
|
172
|
-
sample (`torch.FloatTensor`):
|
183
|
+
sample (`torch.Tensor`):
|
173
184
|
The input sample.
|
174
185
|
timestep (`int`, *optional*):
|
175
186
|
The current timestep in the diffusion chain.
|
176
187
|
|
177
188
|
Returns:
|
178
|
-
`torch.FloatTensor`:
|
189
|
+
`torch.Tensor`:
|
179
190
|
A scaled input sample.
|
180
191
|
"""
|
181
192
|
if self.step_index is None:
|
@@ -200,7 +211,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
200
211
|
self.num_inference_steps = num_inference_steps
|
201
212
|
|
202
213
|
ramp = np.linspace(0, 1, self.num_inference_steps)
|
203
|
-
sigmas = self._compute_sigmas(ramp)
|
214
|
+
if self.config.sigma_schedule == "karras":
|
215
|
+
sigmas = self._compute_karras_sigmas(ramp)
|
216
|
+
elif self.config.sigma_schedule == "exponential":
|
217
|
+
sigmas = self._compute_exponential_sigmas(ramp)
|
204
218
|
|
205
219
|
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
206
220
|
self.timesteps = self.precondition_noise(sigmas)
|
@@ -211,9 +225,8 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
211
225
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
212
226
|
|
213
227
|
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
|
214
|
-
def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
|
228
|
+
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
|
215
229
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
216
|
-
|
217
230
|
sigma_min = sigma_min or self.config.sigma_min
|
218
231
|
sigma_max = sigma_max or self.config.sigma_max
|
219
232
|
|
@@ -221,6 +234,17 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
221
234
|
min_inv_rho = sigma_min ** (1 / rho)
|
222
235
|
max_inv_rho = sigma_max ** (1 / rho)
|
223
236
|
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
237
|
+
|
238
|
+
return sigmas
|
239
|
+
|
240
|
+
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
|
241
|
+
"""Implementation closely follows k-diffusion.
|
242
|
+
|
243
|
+
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
|
244
|
+
"""
|
245
|
+
sigma_min = sigma_min or self.config.sigma_min
|
246
|
+
sigma_max = sigma_max or self.config.sigma_max
|
247
|
+
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
|
224
248
|
return sigmas
|
225
249
|
|
226
250
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
@@ -249,9 +273,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
249
273
|
|
250
274
|
def step(
|
251
275
|
self,
|
252
|
-
model_output: torch.FloatTensor,
|
253
|
-
timestep: Union[float, torch.FloatTensor],
|
254
|
-
sample: torch.FloatTensor,
|
276
|
+
model_output: torch.Tensor,
|
277
|
+
timestep: Union[float, torch.Tensor],
|
278
|
+
sample: torch.Tensor,
|
255
279
|
s_churn: float = 0.0,
|
256
280
|
s_tmin: float = 0.0,
|
257
281
|
s_tmax: float = float("inf"),
|
@@ -264,11 +288,11 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
264
288
|
process from the learned model outputs (most often the predicted noise).
|
265
289
|
|
266
290
|
Args:
|
267
|
-
model_output (`torch.FloatTensor`):
|
291
|
+
model_output (`torch.Tensor`):
|
268
292
|
The direct output from learned diffusion model.
|
269
293
|
timestep (`float`):
|
270
294
|
The current discrete timestep in the diffusion chain.
|
271
|
-
sample (`torch.FloatTensor`):
|
295
|
+
sample (`torch.Tensor`):
|
272
296
|
A current instance of a sample created by the diffusion process.
|
273
297
|
s_churn (`float`):
|
274
298
|
s_tmin (`float`):
|
@@ -278,8 +302,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
278
302
|
generator (`torch.Generator`, *optional*):
|
279
303
|
A random number generator.
|
280
304
|
return_dict (`bool`):
|
281
|
-
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or
|
282
|
-
tuple.
|
305
|
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple.
|
283
306
|
|
284
307
|
Returns:
|
285
308
|
[`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`:
|
@@ -287,11 +310,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
287
310
|
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
288
311
|
"""
|
289
312
|
|
290
|
-
if (
|
291
|
-
isinstance(timestep, int)
|
292
|
-
or isinstance(timestep, torch.IntTensor)
|
293
|
-
or isinstance(timestep, torch.LongTensor)
|
294
|
-
):
|
313
|
+
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
|
295
314
|
raise ValueError(
|
296
315
|
(
|
297
316
|
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
@@ -350,10 +369,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
350
369
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
351
370
|
def add_noise(
|
352
371
|
self,
|
353
|
-
original_samples: torch.FloatTensor,
|
354
|
-
noise: torch.FloatTensor,
|
355
|
-
timesteps: torch.FloatTensor,
|
356
|
-
) -> torch.FloatTensor:
|
372
|
+
original_samples: torch.Tensor,
|
373
|
+
noise: torch.Tensor,
|
374
|
+
timesteps: torch.Tensor,
|
375
|
+
) -> torch.Tensor:
|
357
376
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
358
377
|
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
359
378
|
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
@@ -367,7 +386,11 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|
367
386
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
368
387
|
if self.begin_index is None:
|
369
388
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
389
|
+
elif self.step_index is not None:
|
390
|
+
# add_noise is called after first denoising step (for inpainting)
|
391
|
+
step_indices = [self.step_index] * timesteps.shape[0]
|
370
392
|
else:
|
393
|
+
# add noise is called before first denoising step to create initial latent(img2img)
|
371
394
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
372
395
|
|
373
396
|
sigma = sigmas[step_indices].flatten()
|