diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +38 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +238 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +40 -7
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +6 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
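
The new `diffusers/quantizers/` package listed above (`auto.py`, `base.py`, `bitsandbytes/`, `quantization_config.py`) introduces built-in model quantization. A hedged sketch of how this is typically wired up follows; the `BitsAndBytesConfig` name, the `quantization_config` argument, and the repo id are assumptions modeled on the transformers-style API, not taken from this diff:

```python
# Hedged sketch only: assumes diffusers 0.31.0 exposes BitsAndBytesConfig and a
# `quantization_config` argument on model from_pretrained(), mirroring transformers.
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # example repo id (assumption)
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```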
diffusers/schedulers/scheduling_euler_discrete.py

@@ -20,11 +20,14 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -158,6 +161,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -186,6 +194,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         prediction_type: str = "epsilon",
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -194,6 +204,12 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -235,6 +251,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):

         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
+        self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas

         self._step_index = None
         self._begin_index = None
@@ -332,6 +350,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
         if timesteps is not None and self.config.use_karras_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
         if (
             timesteps is not None
             and self.config.timestep_type == "continuous"
@@ -396,6 +418,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
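
The hunks above add two new sigma schedules to `EulerDiscreteScheduler`: `use_exponential_sigmas` and `use_beta_sigmas`, mutually exclusive with `use_karras_sigmas` (at most one flag may be enabled) and with beta sigmas requiring scipy. A minimal standalone sketch of exercising the new flags (not taken from the diff itself):

```python
# Minimal sketch: enable the new 0.31.0 sigma schedules on EulerDiscreteScheduler.
# Only one of use_karras_sigmas / use_exponential_sigmas / use_beta_sigmas may be
# True, otherwise __init__ raises ValueError; use_beta_sigmas also needs scipy.
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(use_exponential_sigmas=True)
scheduler.set_timesteps(num_inference_steps=30)
print(scheduler.sigmas)  # sigmas are now spaced uniformly in log-space

beta_scheduler = EulerDiscreteScheduler(use_beta_sigmas=True)  # requires scipy
beta_scheduler.set_timesteps(num_inference_steps=30)
```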
@@ -468,6 +498,59 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas

+    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
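
The two helpers added above define the schedules themselves: `_convert_to_exponential` spaces sigmas uniformly between `log(sigma_max)` and `log(sigma_min)`, while `_convert_to_beta` places them via the inverse CDF (`ppf`) of a Beta(alpha, beta) distribution, which concentrates steps near both ends of the sigma range. A standalone sketch of the same math (local names and example values only, not part of the diffusers API):

```python
# Standalone reproduction of the two new schedules for illustration; the
# sigma_max / sigma_min / steps values are arbitrary example numbers.
import math
import numpy as np
import scipy.stats
import torch

sigma_max, sigma_min, steps = 14.6, 0.03, 10

# _convert_to_exponential: uniform spacing in log-sigma.
exponential_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), steps).exp()

# _convert_to_beta: sigma positions taken from the Beta(0.6, 0.6) inverse CDF.
alpha = beta = 0.6
beta_sigmas = torch.tensor(
    [
        sigma_min + ppf * (sigma_max - sigma_min)
        for ppf in scipy.stats.beta.ppf(1 - np.linspace(0, 1, steps), alpha, beta)
    ]
)

print(exponential_sigmas)
print(beta_sigmas)
```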
@@ -555,14 +638,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):

         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

-        noise = randn_tensor(
-            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
-        )
-
-        eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)

         if gamma > 0:
+            noise = randn_tensor(
+                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+            )
+            eps = noise * s_noise
             sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -594,7 +676,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1

         if not return_dict:
-            return (
+            return (
+                prev_sample,
+                pred_original_sample,
+            )

         return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

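
These two hunks change `step()` in two ways: the `randn_tensor` draw is now guarded by `if gamma > 0`, and the tuple returned with `return_dict=False` also carries `pred_original_sample`. The same noise-guard restructuring appears in `FlowMatchHeunDiscreteScheduler` below. A sketch of the observable effect, assuming the default `s_churn=0.0` so that `gamma == 0` (not taken from the diff):

```python
# Sketch: with the noise draw guarded by `if gamma > 0`, a default step
# (s_churn=0.0) draws no noise, so a supplied generator's state is untouched;
# return_dict=False now yields (prev_sample, pred_original_sample).
import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

generator = torch.Generator().manual_seed(0)
state_before = generator.get_state()

sample = torch.randn(1, 4, 8, 8) * scheduler.init_noise_sigma
t = scheduler.timesteps[0]
latent_input = scheduler.scale_model_input(sample, t)
model_output = torch.randn_like(latent_input)  # stand-in for a UNet prediction

prev_sample, pred_original_sample = scheduler.step(
    model_output, t, sample, generator=generator, return_dict=False
)
assert torch.equal(state_before, generator.get_state())  # gamma == 0: no noise drawn
```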
diffusers/schedulers/scheduling_flow_match_heun_discrete.py

@@ -266,14 +266,13 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):

         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

-        noise = randn_tensor(
-            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
-        )
-
-        eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)

         if gamma > 0:
+            noise = randn_tensor(
+                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+            )
+            eps = noise * s_noise
             sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

         if self.state_in_first_order:
diffusers/schedulers/scheduling_heun_discrete.py

@@ -13,13 +13,38 @@
 # limitations under the License.

 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+if is_scipy_available():
+    import scipy.stats
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->HeunDiscrete
+class HeunDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None


 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -97,6 +122,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -117,11 +147,19 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         clip_sample: Optional[bool] = False,
         clip_sample_range: float = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -251,6 +289,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
         if timesteps is not None and self.config.use_karras_sigmas:
             raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")

         num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
@@ -286,6 +328,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -354,6 +402,60 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     @property
     def state_in_first_order(self):
         return self.dt is None
@@ -373,7 +475,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timestep: Union[float, torch.Tensor],
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
-    ) -> Union[
+    ) -> Union[HeunDiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -386,12 +488,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
-                Whether or not to return a [`~schedulers.
+                Whether or not to return a [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or
+                tuple.

         Returns:
-            [`~schedulers.
-                If return_dict is `True`, [`~schedulers.
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -462,9 +565,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1

         if not return_dict:
-            return (
+            return (
+                prev_sample,
+                pred_original_sample,
+            )

-        return
+        return HeunDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
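
`HeunDiscreteScheduler` gains the same `use_exponential_sigmas`/`use_beta_sigmas` flags plus a dedicated `HeunDiscreteSchedulerOutput` that exposes `pred_original_sample`. A minimal sketch of the new output (random stand-in tensors, not taken from the diff):

```python
# Sketch: step() now returns a HeunDiscreteSchedulerOutput carrying
# pred_original_sample next to prev_sample (or a 2-tuple with return_dict=False).
import torch
from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler(use_exponential_sigmas=True)  # new 0.31.0 flag
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 8, 8) * scheduler.init_noise_sigma
t = scheduler.timesteps[0]
model_output = torch.randn_like(sample)  # stand-in for a UNet epsilon prediction

out = scheduler.step(model_output, t, sample)
print(out.prev_sample.shape, out.pred_original_sample.shape)
```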
diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

@@ -13,14 +13,39 @@
 # limitations under the License.

 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, is_scipy_available
 from ..utils.torch_utils import randn_tensor
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+if is_scipy_available():
+    import scipy.stats
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2AncestralDiscrete
+class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None


 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -91,6 +116,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -114,10 +144,18 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         prediction_type: str = "epsilon",
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -250,6 +288,12 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

         self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -346,6 +390,60 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     @property
     def state_in_first_order(self):
         return self.sample is None
@@ -381,7 +479,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.Tensor, np.ndarray],
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
-    ) -> Union[
+    ) -> Union[KDPM2AncestralDiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -396,12 +494,14 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             generator (`torch.Generator`, *optional*):
                 A random number generator.
             return_dict (`bool`):
-                Whether or not to return a
+                Whether or not to return a
+                [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or tuple.

         Returns:
-            [`~schedulers.
-                If return_dict is `True`,
-
+            [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`,
+                [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -424,9 +524,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         gamma = 0
         sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

-        device = model_output.device
-        noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
-
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         if self.config.prediction_type == "epsilon":
             sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
@@ -464,15 +561,23 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             self.sample = None

             prev_sample = sample + derivative * dt
+            noise = randn_tensor(
+                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+            )
             prev_sample = prev_sample + noise * sigma_up

         # upon completion increase step index by one
         self._step_index += 1

         if not return_dict:
-            return (
+            return (
+                prev_sample,
+                pred_original_sample,
+            )

-        return
+        return KDPM2AncestralDiscreteSchedulerOutput(
+            prev_sample=prev_sample, pred_original_sample=pred_original_sample
+        )

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
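
`KDPM2AncestralDiscreteScheduler` gets the same treatment, and its ancestral noise is now drawn (from the optional `generator`) only on the second-order half-step, where `noise * sigma_up` is added. A short sketch of a bare denoising loop with the updated output class (random stand-in model outputs, not taken from the diff):

```python
# Sketch: a bare denoising loop with the updated KDPM2AncestralDiscreteScheduler.
# The generator is only consumed on second-order steps, where noise * sigma_up is added.
import torch
from diffusers import KDPM2AncestralDiscreteScheduler

scheduler = KDPM2AncestralDiscreteScheduler(use_beta_sigmas=True)  # requires scipy
scheduler.set_timesteps(num_inference_steps=10)

generator = torch.Generator().manual_seed(0)
sample = torch.randn(1, 4, 8, 8) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # stand-in for a UNet prediction
    out = scheduler.step(model_output, t, sample, generator=generator)
    sample = out.prev_sample  # out.pred_original_sample is also available now
```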