diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_lms_discrete.py
@@ -17,6 +17,7 @@ from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
+import scipy.stats
 import torch
 from scipy import integrate
 
@@ -111,6 +112,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -134,10 +140,16 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         prediction_type: str = "epsilon",
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -289,6 +301,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
 
@@ -362,6 +380,60 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def step(
         self,
         model_output: torch.Tensor,
@@ -435,7 +507,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
         return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
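A minimal usage sketch of the new LMS options (the values below are illustrative, not taken from the diff): only one of `use_karras_sigmas`, `use_exponential_sigmas`, and `use_beta_sigmas` may be enabled, and `step(..., return_dict=False)` now returns `pred_original_sample` as a second tuple element.

import torch
from diffusers import LMSDiscreteScheduler

# Hedged sketch, assuming diffusers>=0.32.0 with scipy installed.
scheduler = LMSDiscreteScheduler(use_beta_sigmas=True)  # new flag in 0.32.0
scheduler.set_timesteps(num_inference_steps=25)

sample = torch.randn(1, 4, 64, 64)      # stand-in latent
noise_pred = torch.randn(1, 4, 64, 64)  # stand-in model output
prev_sample, pred_original_sample = scheduler.step(
    noise_pred, scheduler.timesteps[0], sample, return_dict=False
)  # the tuple return now carries pred_original_sample as well
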
diffusers/schedulers/scheduling_repaint.py
@@ -319,7 +319,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
         prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
 
         # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
-        prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise
+        prev_known_part = (alpha_prod_t_prev**0.5) * original_image + (1 - alpha_prod_t_prev) * noise
 
         # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
         pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
diffusers/schedulers/scheduling_sasolver.py
@@ -22,11 +22,15 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -122,6 +126,11 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -156,11 +165,21 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         algorithm_type: str = "data_prediction",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -278,12 +297,28 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         )
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)
         if self.config.use_karras_sigmas:
-            log_sigmas = np.log(sigmas)
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_exponential_sigmas:
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_beta_sigmas:
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -364,8 +399,12 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t
 
         return alpha_t, sigma_t
 
@@ -395,6 +434,60 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
@@ -450,10 +543,13 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
             x0_pred = model_output
         elif self.config.prediction_type == "v_prediction":
             x0_pred = alpha_t * sample - sigma_t * model_output
+        elif self.config.prediction_type == "flow_prediction":
+            sigma_t = self.sigmas[self.step_index]
+            x0_pred = sample - sigma_t * model_output
         else:
             raise ValueError(
-                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                " `v_prediction` for the SASolverScheduler."
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                "`v_prediction`, or `flow_prediction` for the SASolverScheduler."
             )
 
         if self.config.thresholding:
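The `use_flow_sigmas` branch above is easy to inspect in isolation. A standalone sketch of the same arithmetic (the constants are illustrative; the scheduler reads `num_train_timesteps` and `flow_shift` from its config): build a linear sigma grid, apply the shift sigma' = shift * sigma / (1 + (shift - 1) * sigma), then derive timesteps as sigma * num_train_timesteps.

import numpy as np

# Illustrative constants, not prescribed by the diff.
num_train_timesteps, flow_shift, num_inference_steps = 1000, 3.0, 8

alphas = np.linspace(1, 1 / num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
# Shift, flip so sigmas run high -> low, and drop the trailing sigma = 0 entry.
sigmas = np.flip(flow_shift * sigmas / (1 + (flow_shift - 1) * sigmas))[:-1].copy()
timesteps = sigmas * num_train_timesteps
print(sigmas.round(4))     # descending, starting just below 1.0
print(timesteps.round(1))  # the matching fractional timesteps
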
diffusers/schedulers/scheduling_tcd.py
@@ -680,16 +680,12 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t
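The effect of the rewrite: whenever an inference schedule exists (`custom_timesteps` or `num_inference_steps` is set), the previous timestep is looked up in `self.timesteps` rather than derived from an assumed uniform spacing; the `timestep - 1` fallback now only covers training-style stepping. A rough standalone illustration of the new control flow (a hypothetical helper, not the scheduler method itself):

import torch

def previous_timestep(timesteps: torch.Tensor, timestep: int, have_schedule: bool) -> torch.Tensor:
    # Mirrors the new TCD/DDPM logic above (hypothetical free function).
    if have_schedule:  # custom timesteps or num_inference_steps set
        index = (timesteps == timestep).nonzero(as_tuple=True)[0][0]
        return torch.tensor(-1) if index == timesteps.shape[0] - 1 else timesteps[index + 1]
    return torch.tensor(timestep - 1)  # training-style fallback

steps = torch.tensor([999, 759, 499, 259])  # an uneven inference schedule
print(previous_timestep(steps, 759, True))  # tensor(499): looked up, not computed from spacing
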
diffusers/schedulers/scheduling_unclip.py
@@ -320,7 +320,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
         pred_prev_sample = pred_prev_sample + variance
 
         if not return_dict:
-            return (pred_prev_sample,)
+            return (
+                pred_prev_sample,
+                pred_original_sample,
+            )
 
         return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
 
diffusers/schedulers/scheduling_unipc_multistep.py
@@ -22,10 +22,14 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -159,6 +163,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -195,11 +204,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         disable_corrector: List[int] = [],
         solver_p: SchedulerMixin = None,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         rescale_betas_zero_snr: bool = False,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -329,6 +348,48 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                     f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                 )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+        elif self.config.use_exponential_sigmas:
+            log_sigmas = np.log(sigmas)
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+        elif self.config.use_beta_sigmas:
+            log_sigmas = np.log(sigmas)
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             if self.config.final_sigmas_type == "sigma_min":
@@ -419,8 +480,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t
 
         return alpha_t, sigma_t
 
@@ -450,6 +515,60 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
@@ -495,10 +614,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             x0_pred = model_output
         elif self.config.prediction_type == "v_prediction":
             x0_pred = alpha_t * sample - sigma_t * model_output
+        elif self.config.prediction_type == "flow_prediction":
+            sigma_t = self.sigmas[self.step_index]
+            x0_pred = sample - sigma_t * model_output
         else:
             raise ValueError(
-                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                " `v_prediction` for the UniPCMultistepScheduler."
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
             )
 
         if self.config.thresholding:
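Taken together, `use_flow_sigmas`, `flow_shift`, and the new `flow_prediction` branch let UniPC drive flow-matching models. A hedged configuration sketch (illustrative values, not an officially documented recipe):

from diffusers import UniPCMultistepScheduler

# Sketch only, assuming diffusers>=0.32.0; the flag combination follows the diff above.
scheduler = UniPCMultistepScheduler(
    solver_order=2,
    prediction_type="flow_prediction",  # x0_pred = sample - sigma_t * model_output
    use_flow_sigmas=True,               # shifted flow schedule instead of beta-derived sigmas
    flow_shift=3.0,                     # hypothetical shift value
)
scheduler.set_timesteps(num_inference_steps=30)
print(scheduler.sigmas[:5])  # monotonically decreasing shifted flow sigmas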