diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. diffusers/__init__.py +34 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +170 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +35 -6
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +2 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
  53. diffusers/pipelines/cogview3/__init__.py +47 -0
  54. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  55. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  56. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  57. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  58. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  60. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  62. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  63. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  64. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  66. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  67. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  68. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  70. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  71. diffusers/pipelines/flux/__init__.py +10 -0
  72. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  73. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  74. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  76. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  77. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  78. diffusers/pipelines/free_noise_utils.py +365 -5
  79. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  80. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  81. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  82. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  83. diffusers/pipelines/kolors/tokenizer.py +4 -0
  84. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  86. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  87. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  89. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  90. diffusers/pipelines/pag/__init__.py +6 -0
  91. diffusers/pipelines/pag/pag_utils.py +8 -2
  92. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  96. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  97. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  98. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  100. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  101. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  102. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  103. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  106. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  107. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  108. diffusers/pipelines/pipeline_utils.py +123 -180
  109. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  111. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  117. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  120. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  121. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  122. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  126. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  127. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  130. diffusers/quantizers/__init__.py +16 -0
  131. diffusers/quantizers/auto.py +126 -0
  132. diffusers/quantizers/base.py +233 -0
  133. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  134. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  135. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  136. diffusers/quantizers/quantization_config.py +391 -0
  137. diffusers/schedulers/scheduling_ddim.py +4 -1
  138. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  140. diffusers/schedulers/scheduling_ddpm.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  142. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  143. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  145. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  146. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  147. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  149. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  150. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  151. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  152. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  153. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  154. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  155. diffusers/schedulers/scheduling_sasolver.py +78 -1
  156. diffusers/schedulers/scheduling_unclip.py +4 -1
  157. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  158. diffusers/training_utils.py +48 -18
  159. diffusers/utils/__init__.py +2 -1
  160. diffusers/utils/dummy_pt_objects.py +60 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
  162. diffusers/utils/hub_utils.py +16 -4
  163. diffusers/utils/import_utils.py +31 -8
  164. diffusers/utils/loading_utils.py +28 -4
  165. diffusers/utils/peft_utils.py +3 -3
  166. diffusers/utils/testing_utils.py +59 -0
  167. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  168. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
  169. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  170. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
  171. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  172. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0

diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -21,11 +21,15 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -161,6 +165,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         use_lu_lambdas (`bool`, *optional*, defaults to `False`):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
@@ -206,6 +215,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         use_lu_lambdas: Optional[bool] = False,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
@@ -214,6 +225,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -330,6 +347,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
         if timesteps is not None and self.config.use_lu_lambdas:
             raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
 
         if timesteps is not None:
             timesteps = np.array(timesteps).astype(np.int64)
@@ -378,6 +399,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
             sigmas = np.exp(lambdas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
@@ -510,6 +537,60 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return lambdas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
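
The hunks above add two new, mutually exclusive sigma schedules to DPMSolverMultistepScheduler. A minimal usage sketch assuming a default-constructed scheduler (the step count and printed output are illustrative and not taken from this diff):

from diffusers import DPMSolverMultistepScheduler

# use_beta_sigmas requires scipy; setting more than one of use_karras_sigmas,
# use_exponential_sigmas and use_beta_sigmas raises a ValueError in __init__.
scheduler = DPMSolverMultistepScheduler(use_beta_sigmas=True)
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.sigmas)  # sigmas spaced by the Beta(0.6, 0.6) quantile mapping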

diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -21,11 +21,15 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -124,6 +128,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -158,11 +167,19 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -213,6 +230,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         self._step_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.use_karras_sigmas = use_karras_sigmas
+        self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas
 
     @property
     def step_index(self):
@@ -267,6 +286,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             timesteps = timesteps.copy().astype(np.int64)
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_max = (
@@ -385,6 +410,60 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
     def convert_model_output(
         self,
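
The `_convert_to_beta` helper copied into both multistep schedulers maps evenly spaced quantiles through the inverse CDF (`ppf`) of a Beta(alpha, beta) distribution and rescales the result onto [sigma_min, sigma_max]. A standalone sketch of that mapping; the endpoint values below are illustrative placeholders, not values from the diff:

import numpy as np
import scipy.stats

def beta_sigmas(sigma_min, sigma_max, num_inference_steps, alpha=0.6, beta=0.6):
    # Quantiles run from 1 down to 0 so the returned sigmas decrease, as in the diff.
    quantiles = 1 - np.linspace(0, 1, num_inference_steps)
    ppf = scipy.stats.beta.ppf(quantiles, alpha, beta)
    return sigma_min + ppf * (sigma_max - sigma_min)

print(beta_sigmas(0.03, 14.6, 10))  # 10 sigmas from ~14.6 down to ~0.03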

diffusers/schedulers/scheduling_dpmsolver_sde.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -20,7 +21,31 @@ import torch
 import torchsde
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+if is_scipy_available():
+    import scipy.stats
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE
+class DPMSolverSDESchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 class BatchedBrownianTree:
@@ -38,7 +63,20 @@ class BatchedBrownianTree:
         except TypeError:
             seed = [seed]
             self.batched = False
-        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+        self.trees = [
+            torchsde.BrownianInterval(
+                t0=t0,
+                t1=t1,
+                size=w0.shape,
+                dtype=w0.dtype,
+                device=w0.device,
+                entropy=s,
+                tol=1e-6,
+                pool_size=24,
+                halfway_tree=True,
+            )
+            for s in seed
+        ]
 
     @staticmethod
     def sort(a, b):
@@ -147,6 +185,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         noise_sampler_seed (`int`, *optional*, defaults to `None`):
             The random seed to use for the noise sampler. If `None`, a random seed is generated.
         timestep_spacing (`str`, defaults to `"linspace"`):
@@ -169,10 +212,18 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         noise_sampler_seed: Optional[int] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -328,6 +379,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)
 
@@ -408,6 +465,60 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     @property
     def state_in_first_order(self):
         return self.sample is None
@@ -419,7 +530,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
         s_noise: float = 1.0,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -431,15 +542,16 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor` or `np.ndarray`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+            return_dict (`bool`):
+                Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
+                tuple.
             s_noise (`float`, *optional*, defaults to 1.0):
                 Scaling factor for noise added to the sample.
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -519,9 +631,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return DPMSolverSDESchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
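
Besides the new sigma schedules, this file switches the Brownian noise source to torchsde.BrownianInterval and gives `step` its own output class: `DPMSolverSDEScheduler.step` now reports `pred_original_sample`, and the `return_dict=False` path returns a two-element tuple instead of `(prev_sample,)`. A rough sketch of how a caller sees this, assuming a bare default scheduler and random stand-in tensors (shapes and step count are placeholders):

import torch
from diffusers import DPMSolverSDEScheduler

scheduler = DPMSolverSDEScheduler()
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
model_output = torch.randn(1, 4, 64, 64)  # stand-in for a UNet prediction

out = scheduler.step(model_output, scheduler.timesteps[0], sample)
print(type(out).__name__)              # DPMSolverSDESchedulerOutput
print(out.pred_original_sample.shape)  # field newly populated in 0.31.0

# With return_dict=False the tuple now carries two elements:
prev_sample, pred_original_sample = scheduler.step(
    model_output, scheduler.timesteps[1], out.prev_sample, return_dict=False
)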

diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -21,11 +21,14 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate, logging
+from ..utils import deprecate, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -123,6 +126,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
             sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
@@ -154,10 +162,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         solver_type: str = "midpoint",
         lower_order_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type == "dpmsolver":
             deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
@@ -300,6 +316,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
 
         num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
@@ -310,6 +330,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         # Clipping the minimum of all lambda(t) for numerical stability.
         # This is critical for cosine (squaredcos_cap_v2) noise schedule.
         clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
+        clipped_idx = clipped_idx.item()
         timesteps = (
             np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
             .round()[::-1][:-1]
@@ -323,6 +344,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
@@ -452,6 +479,60 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
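
The `_convert_to_exponential` helper copied into these schedulers is simply a log-space interpolation between sigma_max and sigma_min. An equivalent standalone sketch; the endpoint values are illustrative placeholders, not values from the diff:

import math
import torch

def exponential_sigmas(sigma_min, sigma_max, num_inference_steps):
    # Evenly spaced in log(sigma), exponentiated back; returned high-to-low as in the diff.
    return torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()

print(exponential_sigmas(0.03, 14.6, 10))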

diffusers/schedulers/scheduling_edm_euler.py
@@ -333,14 +333,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
 
         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
 
-        noise = randn_tensor(
-            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
-        )
-
-        eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)
 
         if gamma > 0:
+            noise = randn_tensor(
+                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+            )
+            eps = noise * s_noise
             sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -360,7 +359,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
         return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -435,7 +435,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
         return EulerAncestralDiscreteSchedulerOutput(
             prev_sample=prev_sample, pred_original_sample=pred_original_sample
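
The last two files change how `step` behaves for callers: EDMEulerScheduler only draws churn noise when `gamma > 0`, so the generator is no longer advanced on plain deterministic steps, and both EDMEulerScheduler and EulerAncestralDiscreteScheduler now return `(prev_sample, pred_original_sample)` instead of a one-element tuple when `return_dict=False`. A hedged sketch against a default EDMEulerScheduler with placeholder tensors (not taken from the diff):

import torch
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler()
scheduler.set_timesteps(num_inference_steps=5)

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
model_output = torch.randn(1, 4, 64, 64)  # stand-in for the model prediction
generator = torch.Generator().manual_seed(0)

state_before = generator.get_state()
prev_sample, pred_original_sample = scheduler.step(  # two elements now, was (prev_sample,)
    model_output, scheduler.timesteps[0], sample, generator=generator, return_dict=False
)
# With the default s_churn=0.0, gamma is 0 and no noise is drawn, so the generator
# state is untouched in 0.31.0, whereas 0.30.x drew noise unconditionally.
print(torch.equal(state_before, generator.get_state()))

Per the final hunk, EulerAncestralDiscreteScheduler.step makes the same tuple change, so exact single-element unpacking written against 0.30.x breaks there as well, while index-based access (`[0]`) keeps working.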