diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_dpmsolver_singlestep.py (+126 -7)

@@ -21,11 +21,14 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate, logging
+from ..utils import deprecate, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -123,6 +126,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
             sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
@@ -154,10 +162,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         solver_type: str = "midpoint",
         lower_order_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type == "dpmsolver":
             deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
@@ -248,6 +266,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 orders = [1, 2] * (steps // 2)
             elif order == 1:
                 orders = [1] * steps
+
+        if self.config.final_sigmas_type == "zero":
+            orders[-1] = 1
+
         return orders

     @property
@@ -300,6 +322,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if timesteps is not None and self.config.use_karras_sigmas:
             raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")

         num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
@@ -310,6 +336,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             # Clipping the minimum of all lambda(t) for numerical stability.
             # This is critical for cosine (squaredcos_cap_v2) noise schedule.
             clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
+            clipped_idx = clipped_idx.item()
             timesteps = (
                 np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
                 .round()[::-1][:-1]
@@ -318,11 +345,24 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             )

         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)
         if self.config.use_karras_sigmas:
-            log_sigmas = np.log(sigmas)
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+        elif self.config.use_exponential_sigmas:
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

@@ -421,8 +461,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t

         return alpha_t, sigma_t

@@ -452,6 +496,60 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
@@ -508,10 +606,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 sigma = self.sigmas[self.step_index]
                 alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                 x0_pred = alpha_t * sample - sigma_t * model_output
+            elif self.config.prediction_type == "flow_prediction":
+                sigma_t = self.sigmas[self.step_index]
+                x0_pred = sample - sigma_t * model_output
             else:
                 raise ValueError(
-                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                    " `v_prediction` for the DPMSolverSinglestepScheduler."
+                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                    "`v_prediction`, or `flow_prediction` for the DPMSolverSinglestepScheduler."
                 )

         if self.config.thresholding:
@@ -729,6 +830,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         model_output_list: List[torch.Tensor],
         *args,
         sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -826,6 +928,23 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
             )
+        elif self.config.algorithm_type == "sde-dpmsolver++":
+            assert noise is not None
+            if self.config.solver_type == "midpoint":
+                x_t = (
+                    (sigma_t / sigma_s2 * torch.exp(-h)) * sample
+                    + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1_1
+                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+                )
+            elif self.config.solver_type == "heun":
+                x_t = (
+                    (sigma_t / sigma_s2 * torch.exp(-h)) * sample
+                    + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h) + (-2.0 * h)) / (-2.0 * h) ** 2 - 0.5)) * D2
+                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+                )
         return x_t

     def singlestep_dpm_solver_update(
@@ -887,7 +1006,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         elif order == 2:
             return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
         elif order == 3:
-            return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
+            return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample, noise=noise)
         else:
             raise ValueError(f"Order must be 1, 2, 3, got {order}")

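The hunks above give DPMSolverSinglestepScheduler new exponential, beta, and flow sigma schedules (beta requires scipy, and at most one of the karras/exponential/beta flags may be enabled), plus an sde-dpmsolver++ third-order path and a flow_prediction prediction type. A minimal usage sketch based only on the constructor flags visible in this diff, not on upstream documentation:

    # Sketch: enabling one of the sigma schedules added in this diff.
    from diffusers import DPMSolverSinglestepScheduler

    # Only one of use_karras_sigmas / use_exponential_sigmas / use_beta_sigmas may be True;
    # use_beta_sigmas additionally needs scipy installed (see the ImportError guard above).
    scheduler = DPMSolverSinglestepScheduler(solver_order=2, use_exponential_sigmas=True)
    scheduler.set_timesteps(num_inference_steps=25)
    print(scheduler.timesteps[:5], scheduler.sigmas[:5])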
diffusers/schedulers/scheduling_edm_euler.py (+8 -6)

@@ -333,14 +333,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):

         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

-        noise = randn_tensor(
-            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
-        )
-
-        eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)

         if gamma > 0:
+            noise = randn_tensor(
+                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+            )
+            eps = noise * s_noise
             sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -360,7 +359,10 @@
         self._step_index += 1

         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )

         return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

diffusers/schedulers/scheduling_euler_ancestral_discrete.py (+4 -1)

@@ -435,7 +435,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1

         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )

         return EulerAncestralDiscreteSchedulerOutput(
             prev_sample=prev_sample, pred_original_sample=pred_original_sample
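Both of these hunks (and the EulerDiscreteScheduler change below) extend the tuple returned by step(..., return_dict=False) to include pred_original_sample next to prev_sample. A small sketch of the call-site impact, assuming a default-configured scheduler and dummy tensors:

    # Sketch: the return_dict=False tuple now has two elements instead of one.
    import torch
    from diffusers import EulerAncestralDiscreteScheduler

    scheduler = EulerAncestralDiscreteScheduler()
    scheduler.set_timesteps(num_inference_steps=10)

    sample = torch.randn(1, 4, 8, 8)
    model_output = torch.randn(1, 4, 8, 8)

    prev_sample, pred_original_sample = scheduler.step(
        model_output, scheduler.timesteps[0], sample, return_dict=False
    )

Code that previously unpacked a single-element tuple may need to account for the extra entry.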
diffusers/schedulers/scheduling_euler_discrete.py (+92 -7)

@@ -20,11 +20,14 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -158,6 +161,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -186,6 +194,8 @@
         prediction_type: str = "epsilon",
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -194,6 +204,12 @@
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -235,6 +251,8 @@

         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
+        self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas

         self._step_index = None
         self._begin_index = None
@@ -332,6 +350,10 @@
             raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
         if timesteps is not None and self.config.use_karras_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
         if (
             timesteps is not None
             and self.config.timestep_type == "continuous"
@@ -396,6 +418,14 @@
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
@@ -468,6 +498,59 @@
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas

+    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
@@ -555,14 +638,13 @@
 
         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

-        noise = randn_tensor(
-            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
-        )
-
-        eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)

         if gamma > 0:
+            noise = randn_tensor(
+                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+            )
+            eps = noise * s_noise
             sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -594,7 +676,10 @@
         self._step_index += 1

         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )

         return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

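EulerDiscreteScheduler gains the same use_exponential_sigmas/use_beta_sigmas flags and carries the _convert_to_exponential/_convert_to_beta helpers that the other schedulers mark as copies. A hedged sketch of switching the Euler schedule to beta sigmas (scipy must be installed for this flag, per the guard added above):

    # Sketch: beta-distributed sigmas on the Euler scheduler, per the flags added in this diff.
    from diffusers import EulerDiscreteScheduler

    scheduler = EulerDiscreteScheduler(use_beta_sigmas=True)  # requires scipy
    scheduler.set_timesteps(num_inference_steps=30)
    print(scheduler.sigmas[:5])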
diffusers/schedulers/scheduling_flow_match_euler_discrete.py (+153 -6)

@@ -20,10 +20,13 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from .scheduling_utils import SchedulerMixin


+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -71,7 +74,18 @@
         max_shift: Optional[float] = 1.15,
         base_image_seq_len: Optional[int] = 256,
         max_image_seq_len: Optional[int] = 4096,
+        invert_sigmas: bool = False,
+        shift_terminal: Optional[float] = None,
+        use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

@@ -85,10 +99,19 @@
         self._step_index = None
         self._begin_index = None

+        self._shift = shift
+
         self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.sigma_min = self.sigmas[-1].item()
         self.sigma_max = self.sigmas[0].item()

+    @property
+    def shift(self):
+        """
+        The value used for shifting.
+        """
+        return self._shift
+
     @property
     def step_index(self):
         """
@@ -114,6 +137,9 @@
         """
         self._begin_index = begin_index

+    def set_shift(self, shift: float):
+        self._shift = shift
+
     def scale_noise(
         self,
         sample: torch.FloatTensor,
@@ -168,6 +194,27 @@
     def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
         return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

+    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
+        r"""
+        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
+        value.
+
+        Reference:
+        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
+
+        Args:
+            t (`torch.Tensor`):
+                A tensor of timesteps to be stretched and shifted.
+
+        Returns:
+            `torch.Tensor`:
+                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
+        """
+        one_minus_z = 1 - t
+        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
+        stretched_t = 1 - (one_minus_z / scale_factor)
+        return stretched_t
+
     def set_timesteps(
         self,
         num_inference_steps: int = None,
@@ -184,29 +231,49 @@
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-
         if self.config.use_dynamic_shifting and mu is None:
             raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")

         if sigmas is None:
-            self.num_inference_steps = num_inference_steps
             timesteps = np.linspace(
                 self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
             )

             sigmas = timesteps / self.config.num_train_timesteps
+        else:
+            sigmas = np.array(sigmas).astype(np.float32)
+            num_inference_steps = len(sigmas)
+        self.num_inference_steps = num_inference_steps

         if self.config.use_dynamic_shifting:
             sigmas = self.time_shift(mu, 1.0, sigmas)
         else:
-            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
+            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
+
+        if self.config.shift_terminal:
+            sigmas = self.stretch_shift_to_terminal(sigmas)
+
+        if self.config.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)

         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
         timesteps = sigmas * self.config.num_train_timesteps

-        self.timesteps = timesteps.to(device=device)
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.invert_sigmas:
+            sigmas = 1.0 - sigmas
+            timesteps = sigmas * self.config.num_train_timesteps
+            sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
+        else:
+            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

+        self.timesteps = timesteps.to(device=device)
+        self.sigmas = sigmas
         self._step_index = None
         self._begin_index = None

@@ -307,5 +374,85 @@

         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def __len__(self):
         return self.config.num_train_timesteps