diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_flow_match_heun_discrete.py
@@ -266,14 +266,13 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
 
-        noise = randn_tensor(
-            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
-        )
-
-        eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)
 
         if gamma > 0:
+            noise = randn_tensor(
+                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+            )
+            eps = noise * s_noise
             sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
 
         if self.state_in_first_order:
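The hunk above only draws churn noise when it is actually used. A minimal sketch of the observable effect (not part of the diff; it assumes `step()` keeps a default `s_churn=0.0` and accepts a `generator`, as the surrounding code suggests): with zero churn, `randn_tensor` is no longer called, so a supplied generator's state is left untouched.

import torch
from diffusers import FlowMatchHeunDiscreteScheduler

scheduler = FlowMatchHeunDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=4)

sample = torch.randn(1, 4, 8, 8)
model_output = torch.randn(1, 4, 8, 8)

generator = torch.Generator().manual_seed(0)
state_before = generator.get_state()

# Default s_churn=0.0 -> gamma == 0 -> the 0.32.0 code skips randn_tensor entirely.
scheduler.step(model_output, scheduler.timesteps[0], sample, generator=generator)

# In 0.30.3 the generator was consumed even for churn-free sampling; after this change it is not.
assert torch.equal(state_before, generator.get_state())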
diffusers/schedulers/scheduling_heun_discrete.py
@@ -13,13 +13,38 @@
 # limitations under the License.
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+if is_scipy_available():
+    import scipy.stats
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->HeunDiscrete
+class HeunDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -97,6 +122,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -117,11 +147,19 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         clip_sample: Optional[bool] = False,
         clip_sample_range: float = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -251,6 +289,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
         if timesteps is not None and self.config.use_karras_sigmas:
             raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
 
         num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
@@ -286,6 +328,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -354,6 +402,60 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     @property
     def state_in_first_order(self):
         return self.dt is None
@@ -373,7 +475,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timestep: Union[float, torch.Tensor],
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[HeunDiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -386,12 +488,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+                Whether or not to return a [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or
+                tuple.
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -462,9 +565,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return HeunDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
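Taken together, the HeunDiscreteScheduler hunks add two opt-in sigma schedules (`use_exponential_sigmas` and `use_beta_sigmas`, the latter requiring scipy) and replace the generic `SchedulerOutput` with a `HeunDiscreteSchedulerOutput` that also carries `pred_original_sample`. A minimal usage sketch against 0.32.0 (illustrative shapes and step counts, not taken from the diff):

import torch
from diffusers import HeunDiscreteScheduler

# Only one of use_karras_sigmas / use_exponential_sigmas / use_beta_sigmas may be enabled;
# use_beta_sigmas additionally requires scipy to be installed.
scheduler = HeunDiscreteScheduler(use_beta_sigmas=True)
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 64, 64)
model_output = torch.randn(1, 4, 64, 64)

out = scheduler.step(model_output, scheduler.timesteps[0], sample)
prev_sample = out.prev_sample           # same field as before
pred_x0 = out.pred_original_sample      # new: denoised estimate, e.g. for previews or guidance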
diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -13,14 +13,39 @@
 # limitations under the License.
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, is_scipy_available
 from ..utils.torch_utils import randn_tensor
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+if is_scipy_available():
+    import scipy.stats
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2AncestralDiscrete
+class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -91,6 +116,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -114,10 +144,18 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         prediction_type: str = "epsilon",
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -250,6 +288,12 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -346,6 +390,60 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     @property
     def state_in_first_order(self):
         return self.sample is None
@@ -381,7 +479,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.Tensor, np.ndarray],
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[KDPM2AncestralDiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).
@@ -396,12 +494,14 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             generator (`torch.Generator`, *optional*):
                 A random number generator.
             return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+                Whether or not to return a
+                [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or tuple.
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_ddim.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`,
+                [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -424,9 +524,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             gamma = 0
         sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now
 
-        device = model_output.device
-        noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
-
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         if self.config.prediction_type == "epsilon":
             sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
@@ -464,15 +561,23 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
            self.sample = None
 
             prev_sample = sample + derivative * dt
+            noise = randn_tensor(
+                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+            )
             prev_sample = prev_sample + noise * sigma_up
 
         # upon completion increase step index by one
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return KDPM2AncestralDiscreteSchedulerOutput(
+            prev_sample=prev_sample, pred_original_sample=pred_original_sample
+        )
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
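The same pattern is applied to KDPM2AncestralDiscreteScheduler, with one extra consequence for callers: with `return_dict=False`, `step()` now returns a two-element tuple rather than a 1-tuple. A short sketch of the updated unpacking (illustrative inputs; assumes diffusers 0.32.0):

import torch
from diffusers import KDPM2AncestralDiscreteScheduler

scheduler = KDPM2AncestralDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 64, 64)
model_output = torch.randn(1, 4, 64, 64)
generator = torch.Generator().manual_seed(0)

# Previously: (prev_sample,) = scheduler.step(..., return_dict=False)
prev_sample, pred_original_sample = scheduler.step(
    model_output, scheduler.timesteps[0], sample, generator=generator, return_dict=False
)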
diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -13,13 +13,38 @@
 # limitations under the License.
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+if is_scipy_available():
+    import scipy.stats
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2Discrete
+class KDPM2DiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -90,6 +115,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -113,10 +143,18 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         prediction_type: str = "epsilon",
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -249,6 +287,12 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -359,13 +403,67 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def step(
         self,
         model_output: Union[torch.Tensor, np.ndarray],
         timestep: Union[float, torch.Tensor],
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[KDPM2DiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -378,12 +476,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+                Whether or not to return a [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or
+                tuple.
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -445,9 +544,12 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         prev_sample = sample + derivative * dt
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return KDPM2DiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
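KDPM2DiscreteScheduler receives the identical additions. For reference, the `_convert_to_beta` helper copied into all three schedulers spaces sigmas by the inverse CDF of a Beta(0.6, 0.6) distribution between `sigma_min` and `sigma_max`; a standalone sketch of that mapping (same formula as in the diff, with illustrative endpoint values):

import numpy as np
import scipy.stats


def beta_sigmas(sigma_min, sigma_max, num_inference_steps, alpha=0.6, beta=0.6):
    # Quantiles run from 1 down to 0, so the sigmas decrease from sigma_max to sigma_min.
    ppfs = [scipy.stats.beta.ppf(t, alpha, beta) for t in 1 - np.linspace(0, 1, num_inference_steps)]
    return np.array([sigma_min + ppf * (sigma_max - sigma_min) for ppf in ppfs])


print(beta_sigmas(0.03, 14.6, num_inference_steps=8))  # decreasing schedule, denser near both ends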
diffusers/schedulers/scheduling_lcm.py
@@ -643,16 +643,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t
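The LCMScheduler change makes `previous_timestep` follow the actual inference schedule whenever `num_inference_steps` is set, instead of stepping back by `num_train_timesteps // num_inference_steps`; without any schedule it now simply returns `timestep - 1`. A standalone sketch of the new lookup branch (illustrative timestep values, not from the diff):

import torch


def previous_timestep(timesteps: torch.Tensor, timestep: int) -> torch.Tensor:
    # Mirrors the branch taken when custom_timesteps or num_inference_steps is set.
    index = (timesteps == timestep).nonzero(as_tuple=True)[0][0]
    if index == timesteps.shape[0] - 1:
        return torch.tensor(-1)  # already at the last scheduled step
    return timesteps[index + 1]


timesteps = torch.tensor([999, 759, 499, 259])  # an illustrative 4-step schedule
print(previous_timestep(timesteps, 759))  # tensor(499)
print(previous_timestep(timesteps, 259))  # tensor(-1)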