diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6

@@ -21,11 +21,15 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -161,6 +165,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         use_lu_lambdas (`bool`, *optional*, defaults to `False`):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
@@ -206,7 +215,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         use_lu_lambdas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
@@ -214,6 +227,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
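Note: the constructor now validates the new flags eagerly: beta sigmas require scipy, and the three sigma-schedule flags are mutually exclusive. A minimal sketch of the resulting behavior, constructing the scheduler directly outside any pipeline:

from diffusers import DPMSolverMultistepScheduler

# Combining two sigma-schedule flags now fails at construction time.
try:
    DPMSolverMultistepScheduler(use_beta_sigmas=True, use_karras_sigmas=True)
except ValueError as err:
    print(err)

# A single flag is fine (beta sigmas additionally require scipy).
scheduler = DPMSolverMultistepScheduler(use_exponential_sigmas=True)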
@@ -330,6 +349,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
         if timesteps is not None and self.config.use_lu_lambdas:
             raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
 
         if timesteps is not None:
             timesteps = np.array(timesteps).astype(np.int64)
@@ -378,6 +401,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
             sigmas = np.exp(lambdas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+        elif self.config.use_exponential_sigmas:
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
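Note: the `use_flow_sigmas` branch implements a shifted flow-matching schedule; `flow_shift > 1` concentrates steps at high noise. A standalone numpy sketch of the same arithmetic, with illustrative values:

import numpy as np

num_train_timesteps, num_inference_steps, flow_shift = 1000, 4, 3.0

# Same arithmetic as the new `use_flow_sigmas` branch above: linear sigmas in
# (0, 1), then a shift that pushes steps toward the high-noise end.
alphas = np.linspace(1, 1 / num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(flow_shift * sigmas / (1 + (flow_shift - 1) * sigmas))[:-1].copy()
timesteps = sigmas * num_train_timesteps

print(sigmas.round(3))  # approximately [1.0, 0.9, 0.75, 0.5]: denser at high noise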
@@ -466,8 +502,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return t
 
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t
 
         return alpha_t, sigma_t
 
@@ -510,6 +550,60 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return lambdas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
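Note: `_convert_to_beta` spaces sigmas by percentiles of a Beta(0.6, 0.6) distribution, which clusters steps near both ends of the sigma range. The same computation as a standalone sketch (the helper name is ours, for illustration only):

import numpy as np
import scipy.stats

def beta_sigmas(sigma_min, sigma_max, num_inference_steps, alpha=0.6, beta=0.6):
    # Beta(alpha, beta) percentile-point function on a reversed uniform grid,
    # mapped onto [sigma_min, sigma_max]; mirrors `_convert_to_beta` above.
    ppfs = scipy.stats.beta.ppf(1 - np.linspace(0, 1, num_inference_steps), alpha, beta)
    return sigma_min + ppfs * (sigma_max - sigma_min)

print(beta_sigmas(0.03, 14.6, 7).round(2))  # decreasing, dense near both endpoints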
@@ -567,10 +661,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             sigma = self.sigmas[self.step_index]
             alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
             x0_pred = alpha_t * sample - sigma_t * model_output
+        elif self.config.prediction_type == "flow_prediction":
+            sigma_t = self.sigmas[self.step_index]
+            x0_pred = sample - sigma_t * model_output
         else:
             raise ValueError(
-                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                " `v_prediction` for the DPMSolverMultistepScheduler."
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
             )
 
         if self.config.thresholding:
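Note: a short derivation of why `x0_pred = sample - sigma_t * model_output` is exact under the usual flow-matching conventions (assumed here; the diff itself does not state them):

x_t = (1 - \sigma_t)\,x_0 + \sigma_t\,\epsilon, \qquad v_\theta \approx \epsilon - x_0

\Rightarrow\; x_t - \sigma_t\,v_\theta = (1 - \sigma_t)\,x_0 + \sigma_t\,\epsilon - \sigma_t(\epsilon - x_0) = x_0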
@@ -806,6 +903,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         model_output_list: List[torch.Tensor],
         *args,
         sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -884,6 +982,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
             )
+        elif self.config.algorithm_type == "sde-dpmsolver++":
+            assert noise is not None
+            x_t = (
+                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+                + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+                + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
+                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+            )
         return x_t
 
     def index_for_timestep(self, timestep, schedule_timesteps=None):
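Note: with `noise` now threaded into the third-order update, `sde-dpmsolver++` can run at `solver_order=3`; previously the third-order branch had no stochastic case. A configuration sketch (flags as shown in this diff):

from diffusers import DPMSolverMultistepScheduler

# Third-order stochastic multistep solver, enabled by the branch added above.
scheduler = DPMSolverMultistepScheduler(
    algorithm_type="sde-dpmsolver++",
    solver_order=3,
)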
@@ -990,7 +1097,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
             prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
         else:
-            prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
+            prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
 
         if self.lower_order_nums < self.config.solver_order:
             self.lower_order_nums += 1
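Taken together, a minimal usage sketch for the multistep scheduler (the checkpoint id is a placeholder; `from_config` overrides are standard diffusers usage rather than something introduced by this diff):

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# "stable-diffusion-v1-5/stable-diffusion-v1-5" stands in for any checkpoint.
pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    use_beta_sigmas=True,  # or use_exponential_sigmas / use_flow_sigmas
)
image = pipe("a watercolor fox", num_inference_steps=20).images[0]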
diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5

@@ -21,11 +21,15 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -124,6 +128,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -158,11 +167,21 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -213,6 +232,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         self._step_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.use_karras_sigmas = use_karras_sigmas
+        self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas
 
     @property
     def step_index(self):
@@ -267,6 +288,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             timesteps = timesteps.copy().astype(np.int64)
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_max = (
@@ -354,8 +389,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t
 
         return alpha_t, sigma_t
 
@@ -385,6 +424,60 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
     def convert_model_output(
         self,
@@ -443,10 +536,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
             sigma = self.sigmas[self.step_index]
             alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
             x0_pred = alpha_t * sample - sigma_t * model_output
+        elif self.config.prediction_type == "flow_prediction":
+            sigma_t = self.sigmas[self.step_index]
+            x0_pred = sample - sigma_t * model_output
         else:
             raise ValueError(
-                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                " `v_prediction` for the DPMSolverMultistepScheduler."
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
             )
 
         if self.config.thresholding:
@@ -685,6 +781,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         model_output_list: List[torch.Tensor],
         *args,
         sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -763,6 +860,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
                 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
             )
+        elif self.config.algorithm_type == "sde-dpmsolver++":
+            assert noise is not None
+            x_t = (
+                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+                + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+                + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
+                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+            )
         return x_t
 
     def _init_step_index(self, timestep):
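The inverse scheduler (used to deterministically noise samples, e.g. for DiffEdit-style editing) picks up the same configuration surface. A hedged sketch, assuming the flags behave as in the forward scheduler:

from diffusers import DPMSolverMultistepInverseScheduler

# Same new flags as the forward scheduler; at most one sigma schedule at a time.
inverse_scheduler = DPMSolverMultistepInverseScheduler(use_exponential_sigmas=True)
inverse_scheduler.set_timesteps(num_inference_steps=25)
print(inverse_scheduler.sigmas[:5])  # inverse direction: increasing toward high noise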
diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -20,7 +21,31 @@ import torch
 import torchsde
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+if is_scipy_available():
+    import scipy.stats
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE
+class DPMSolverSDESchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 class BatchedBrownianTree:
@@ -38,7 +63,20 @@ class BatchedBrownianTree:
         except TypeError:
             seed = [seed]
             self.batched = False
-        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+        self.trees = [
+            torchsde.BrownianInterval(
+                t0=t0,
+                t1=t1,
+                size=w0.shape,
+                dtype=w0.dtype,
+                device=w0.device,
+                entropy=s,
+                tol=1e-6,
+                pool_size=24,
+                halfway_tree=True,
+            )
+            for s in seed
+        ]
 
     @staticmethod
     def sort(a, b):
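Note: `torchsde.BrownianInterval` replaces the `torchsde.BrownianTree` wrapper used previously; it serves Brownian increments over arbitrary sub-intervals, deterministically for a given `entropy`. A standalone sketch, independent of diffusers (parameter values as in the diff):

import torch
import torchsde

# Deterministic for a fixed `entropy`; a query returns W(tb) - W(ta).
bm = torchsde.BrownianInterval(
    t0=0.0, t1=1.0, size=(2, 4), dtype=torch.float32, device="cpu",
    entropy=42, tol=1e-6, pool_size=24, halfway_tree=True,
)
dw = bm(0.25, 0.5)  # shape (2, 4); each entry ~ N(0, 0.25)
print(dw.shape)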
@@ -147,6 +185,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         noise_sampler_seed (`int`, *optional*, defaults to `None`):
             The random seed to use for the noise sampler. If `None`, a random seed is generated.
         timestep_spacing (`str`, defaults to `"linspace"`):
@@ -169,10 +212,18 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         noise_sampler_seed: Optional[int] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -328,6 +379,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)
 
@@ -408,6 +465,60 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     @property
     def state_in_first_order(self):
         return self.sample is None
@@ -419,7 +530,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
         s_noise: float = 1.0,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -431,15 +542,16 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor` or `np.ndarray`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+            return_dict (`bool`):
+                Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
+                tuple.
             s_noise (`float`, *optional*, defaults to 1.0):
                 Scaling factor for noise added to the sample.
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -519,9 +631,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return DPMSolverSDESchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(