diffusers-0.30.3-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
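Beyond the renames and the scheduler work excerpted below, the list above shows an entirely new `diffusers/quantizers/` package with bitsandbytes, GGUF, and torchao backends. A hedged sketch of the 4-bit bitsandbytes path, assuming the `bitsandbytes` package is installed and following the naming the diffusers quantization docs use (the checkpoint id is illustrative):

```python
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# NF4 4-bit quantization config; the field names mirror transformers' API.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
```

The hunks below are excerpted from the three DPM-Solver scheduler files, which gain exponential and beta sigma schedules (and, in the two multistep solvers, flow-matching support).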
diffusers/schedulers/scheduling_dpmsolver_multistep.py

@@ -21,11 +21,15 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -161,6 +165,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         use_lu_lambdas (`bool`, *optional*, defaults to `False`):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
@@ -206,7 +215,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         use_lu_lambdas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
@@ -214,6 +227,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -330,6 +349,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
         if timesteps is not None and self.config.use_lu_lambdas:
             raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
 
         if timesteps is not None:
             timesteps = np.array(timesteps).astype(np.int64)
@@ -378,6 +401,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
             sigmas = np.exp(lambdas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+        elif self.config.use_exponential_sigmas:
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
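The `use_flow_sigmas` branch above builds a linear σ grid on [0, 1) and warps it with the shift formula σ' = s·σ / (1 + (s − 1)·σ). A standalone NumPy sketch of that computation (the step count and shift value are arbitrary):

```python
import numpy as np

# Mirrors the use_flow_sigmas branch above, outside the scheduler.
num_train_timesteps, num_inference_steps, flow_shift = 1000, 4, 3.0

alphas = np.linspace(1, 1 / num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas  # linear sigmas rising from 0 toward 1
# Shift: sigma' = s * sigma / (1 + (s - 1) * sigma), then descending order
sigmas = np.flip(flow_shift * sigmas / (1 + (flow_shift - 1) * sigmas))[:-1].copy()
timesteps = sigmas * num_train_timesteps

print(np.round(sigmas, 4))  # [0.9997 0.8996 0.7496 0.4997]
```

A shift greater than 1 pushes the sigmas toward the high-noise end, which is how flow-matching models typically allocate their sampling budget.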
@@ -466,8 +502,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return t
 
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t
 
         return alpha_t, sigma_t
 
@@ -510,6 +550,60 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return lambdas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
@@ -567,10 +661,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 sigma = self.sigmas[self.step_index]
                 alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                 x0_pred = alpha_t * sample - sigma_t * model_output
+            elif self.config.prediction_type == "flow_prediction":
+                sigma_t = self.sigmas[self.step_index]
+                x0_pred = sample - sigma_t * model_output
             else:
                 raise ValueError(
-                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                    " `v_prediction` for the DPMSolverMultistepScheduler."
+                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                    "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
                 )
 
             if self.config.thresholding:
@@ -806,6 +903,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         model_output_list: List[torch.Tensor],
         *args,
         sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -884,6 +982,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
             )
+        elif self.config.algorithm_type == "sde-dpmsolver++":
+            assert noise is not None
+            x_t = (
+                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+                + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+                + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
+                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+            )
         return x_t
 
     def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -990,7 +1097,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
             prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
         else:
-            prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
+            prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
 
         if self.lower_order_nums < self.config.solver_order:
             self.lower_order_nums += 1
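A minimal usage sketch of the schedule options this file gains (all values illustrative; `use_beta_sigmas` requires `scipy`, and at most one of the karras/exponential/beta flags may be set, per the ValueError added above):

```python
from diffusers import DPMSolverMultistepScheduler

# Beta sigmas; combining this with use_karras_sigmas or
# use_exponential_sigmas raises the new ValueError.
scheduler = DPMSolverMultistepScheduler(use_beta_sigmas=True)
scheduler.set_timesteps(num_inference_steps=20)

# Flow-matching schedule plus the new flow_prediction parametrization.
flow_scheduler = DPMSolverMultistepScheduler(
    use_flow_sigmas=True, flow_shift=3.0, prediction_type="flow_prediction"
)
flow_scheduler.set_timesteps(num_inference_steps=20)
print(flow_scheduler.sigmas[:3])
```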
diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

@@ -21,11 +21,15 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -124,6 +128,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -158,11 +167,21 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -213,6 +232,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         self._step_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.use_karras_sigmas = use_karras_sigmas
+        self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas
 
     @property
     def step_index(self):
@@ -267,6 +288,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             timesteps = timesteps.copy().astype(np.int64)
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_max = (
@@ -354,8 +389,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t
 
         return alpha_t, sigma_t
 
@@ -385,6 +424,60 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
     def convert_model_output(
         self,
@@ -443,10 +536,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
                 sigma = self.sigmas[self.step_index]
                 alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                 x0_pred = alpha_t * sample - sigma_t * model_output
+            elif self.config.prediction_type == "flow_prediction":
+                sigma_t = self.sigmas[self.step_index]
+                x0_pred = sample - sigma_t * model_output
             else:
                 raise ValueError(
-                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                    " `v_prediction` for the DPMSolverMultistepScheduler."
+                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                    "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
                 )
 
             if self.config.thresholding:
@@ -685,6 +781,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         model_output_list: List[torch.Tensor],
         *args,
         sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -763,6 +860,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
                 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
             )
+        elif self.config.algorithm_type == "sde-dpmsolver++":
+            assert noise is not None
+            x_t = (
+                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+                + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+                + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
+                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+            )
         return x_t
 
     def _init_step_index(self, timestep):
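Both multistep solvers now route `noise` into `multistep_dpm_solver_third_order_update`, so `algorithm_type="sde-dpmsolver++"` gains a third-order update path. A hedged sketch of opting in (the checkpoint id is illustrative):

```python
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # illustrative checkpoint
    torch_dtype=torch.float16,
)
# Rebuild the scheduler from the pipeline's config with the SDE variant
# at third order, which the hunks above wire through.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=3
)
```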
diffusers/schedulers/scheduling_dpmsolver_sde.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -20,7 +21,31 @@ import torch
 import torchsde
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+if is_scipy_available():
+    import scipy.stats
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE
+class DPMSolverSDESchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 class BatchedBrownianTree:
@@ -38,7 +63,20 @@ class BatchedBrownianTree:
         except TypeError:
             seed = [seed]
             self.batched = False
-        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+        self.trees = [
+            torchsde.BrownianInterval(
+                t0=t0,
+                t1=t1,
+                size=w0.shape,
+                dtype=w0.dtype,
+                device=w0.device,
+                entropy=s,
+                tol=1e-6,
+                pool_size=24,
+                halfway_tree=True,
+            )
+            for s in seed
+        ]
 
     @staticmethod
     def sort(a, b):
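The noise source swaps `torchsde.BrownianTree` for `torchsde.BrownianInterval`, torchsde's recommended interface, which is parameterized by the sample's shape, dtype, and device rather than a concrete `w0` tensor. A standalone sketch of querying Brownian increments this way (the latent shape is a placeholder):

```python
import torch
import torchsde

shape = (1, 4, 64, 64)  # placeholder latent shape
bm = torchsde.BrownianInterval(
    t0=0.0,
    t1=1.0,
    size=shape,
    dtype=torch.float32,
    device="cpu",
    entropy=42,  # seed, as in the diff's entropy=s
)
dW = bm(0.2, 0.3)  # increment W(0.3) - W(0.2); variance scales with the gap
print(dW.shape)    # torch.Size([1, 4, 64, 64])
```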
@@ -147,6 +185,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         noise_sampler_seed (`int`, *optional*, defaults to `None`):
             The random seed to use for the noise sampler. If `None`, a random seed is generated.
         timestep_spacing (`str`, defaults to `"linspace"`):
@@ -169,10 +212,18 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         noise_sampler_seed: Optional[int] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -328,6 +379,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)
 
@@ -408,6 +465,60 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     @property
     def state_in_first_order(self):
         return self.sample is None
@@ -419,7 +530,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
         s_noise: float = 1.0,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -431,15 +542,16 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor` or `np.ndarray`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+            return_dict (`bool`):
+                Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
+                tuple.
             s_noise (`float`, *optional*, defaults to 1.0):
                 Scaling factor for noise added to the sample.
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -519,9 +631,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return DPMSolverSDESchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
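With these changes, `DPMSolverSDEScheduler.step` exposes the denoised estimate alongside the previous sample, and the `return_dict=False` tuple grows from one element to two. A minimal sketch of a single denoising step (toy tensor shapes and a stand-in model output; requires `torchsde`):

```python
import torch
from diffusers import DPMSolverSDEScheduler

scheduler = DPMSolverSDEScheduler(noise_sampler_seed=0)
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 8, 8) * scheduler.init_noise_sigma
t = scheduler.timesteps[0]
model_output = torch.randn_like(sample)  # stand-in for a denoiser call

out = scheduler.step(model_output, t, sample)  # DPMSolverSDESchedulerOutput
sample = out.prev_sample
preview = out.pred_original_sample  # newly returned x0 estimate
# With return_dict=False the result is (prev_sample, pred_original_sample).
```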