diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7

@@ -21,11 +21,14 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate, logging
+from ..utils import deprecate, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -123,6 +126,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
             sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
@@ -154,10 +162,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         solver_type: str = "midpoint",
         lower_order_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type == "dpmsolver":
             deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
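
Of the new options above, only one of `use_beta_sigmas`, `use_exponential_sigmas`, and `use_karras_sigmas` may be enabled at a time, and beta sigmas require scipy. A minimal sketch of opting into one of the new schedules from an existing pipeline config (the checkpoint name is illustrative, not taken from this diff):

    from diffusers import DiffusionPipeline, DPMSolverSinglestepScheduler

    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    # Rebuild the scheduler with one of the newly added schedules; only one
    # of the use_*_sigmas flags may be True at a time.
    pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
        pipe.scheduler.config, use_beta_sigmas=True
    )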
@@ -248,6 +266,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             orders = [1, 2] * (steps // 2)
         elif order == 1:
             orders = [1] * steps
+
+        if self.config.final_sigmas_type == "zero":
+            orders[-1] = 1
+
         return orders
 
     @property
@@ -300,6 +322,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if timesteps is not None and self.config.use_karras_sigmas:
             raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
 
         num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
@@ -310,6 +336,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         # Clipping the minimum of all lambda(t) for numerical stability.
         # This is critical for cosine (squaredcos_cap_v2) noise schedule.
         clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
+        clipped_idx = clipped_idx.item()
         timesteps = (
             np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
             .round()[::-1][:-1]
@@ -318,11 +345,24 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         )
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)
         if self.config.use_karras_sigmas:
-            log_sigmas = np.log(sigmas)
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+        elif self.config.use_exponential_sigmas:
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
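
In the `use_flow_sigmas` branch above, the sigma grid comes from a linear alpha ramp with the flow shift s applied as sigma' = s * sigma / (1 + (s - 1) * sigma). A standalone numpy sketch of that branch, with illustrative values:

    import numpy as np

    num_train_timesteps, num_inference_steps, flow_shift = 1000, 4, 3.0
    alphas = np.linspace(1, 1 / num_train_timesteps, num_inference_steps + 1)
    sigmas = 1.0 - alphas
    # Larger flow_shift values push the schedule toward higher sigmas.
    sigmas = np.flip(flow_shift * sigmas / (1 + (flow_shift - 1) * sigmas))[:-1].copy()
    timesteps = sigmas * num_train_timesteps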
@@ -421,8 +461,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t
 
         return alpha_t, sigma_t
 
@@ -452,6 +496,60 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
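
`_convert_to_beta` maps evenly spaced quantiles through the inverse CDF of a Beta(alpha, beta) distribution, which for the default alpha = beta = 0.6 places more steps near both ends of the [sigma_min, sigma_max] range. A vectorized standalone sketch (the sigma bounds are illustrative; scipy's ppf accepts array input directly):

    import numpy as np
    import scipy.stats

    sigma_min, sigma_max, num_inference_steps = 0.03, 14.6, 10
    ppfs = scipy.stats.beta.ppf(1 - np.linspace(0, 1, num_inference_steps), 0.6, 0.6)
    sigmas = sigma_min + ppfs * (sigma_max - sigma_min)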
@@ -508,10 +606,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 sigma = self.sigmas[self.step_index]
                 alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                 x0_pred = alpha_t * sample - sigma_t * model_output
+            elif self.config.prediction_type == "flow_prediction":
+                sigma_t = self.sigmas[self.step_index]
+                x0_pred = sample - sigma_t * model_output
             else:
                 raise ValueError(
-                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                    " `v_prediction` for the DPMSolverSinglestepScheduler."
+                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                    "`v_prediction`, or `flow_prediction` for the DPMSolverSinglestepScheduler."
                 )
 
         if self.config.thresholding:
@@ -729,6 +830,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         model_output_list: List[torch.Tensor],
         *args,
         sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -826,6 +928,23 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
             )
+        elif self.config.algorithm_type == "sde-dpmsolver++":
+            assert noise is not None
+            if self.config.solver_type == "midpoint":
+                x_t = (
+                    (sigma_t / sigma_s2 * torch.exp(-h)) * sample
+                    + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1_1
+                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+                )
+            elif self.config.solver_type == "heun":
+                x_t = (
+                    (sigma_t / sigma_s2 * torch.exp(-h)) * sample
+                    + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h) + (-2.0 * h)) / (-2.0 * h) ** 2 - 0.5)) * D2
+                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+                )
         return x_t
 
     def singlestep_dpm_solver_update(
@@ -887,7 +1006,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         elif order == 2:
             return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
         elif order == 3:
-            return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
+            return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample, noise=noise)
         else:
             raise ValueError(f"Order must be 1, 2, 3, got {order}")
 
diffusers/schedulers/scheduling_edm_euler.py +8 -6

@@ -333,14 +333,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
 
         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
 
-        noise = randn_tensor(
-            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
-        )
-
-        eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)
 
         if gamma > 0:
+            noise = randn_tensor(
+                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+            )
+            eps = noise * s_noise
             sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -360,7 +359,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
         return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1

@@ -435,7 +435,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
         return EulerAncestralDiscreteSchedulerOutput(
             prev_sample=prev_sample, pred_original_sample=pred_original_sample
diffusers/schedulers/scheduling_euler_discrete.py +92 -7

@@ -20,11 +20,14 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -158,6 +161,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -186,6 +194,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         prediction_type: str = "epsilon",
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -194,6 +204,12 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -235,6 +251,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
+        self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas
 
         self._step_index = None
         self._begin_index = None
@@ -332,6 +350,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
         if timesteps is not None and self.config.use_karras_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
         if (
             timesteps is not None
             and self.config.timestep_type == "continuous"
@@ -396,6 +418,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
@@ -468,6 +498,59 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
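
`_convert_to_exponential` is a log-linear (geometric) interpolation between `sigma_max` and `sigma_min`, per the k-diffusion reference linked above. A standalone sketch with illustrative bounds:

    import math

    import numpy as np

    sigma_min, sigma_max, num_inference_steps = 0.03, 14.6, 10
    # Uniform steps in log-sigma space, i.e. geometric spacing of sigmas.
    sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))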
@@ -555,14 +638,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
 
-        noise = randn_tensor(
-            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
-        )
-
-        eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)
 
         if gamma > 0:
+            noise = randn_tensor(
+                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+            )
+            eps = noise * s_noise
             sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -594,7 +676,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
         return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
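
As with the EDM and ancestral Euler schedulers above, `step(..., return_dict=False)` now returns `pred_original_sample` alongside `prev_sample`. A runnable sketch of unpacking the new two-element tuple (tensor shapes are illustrative):

    import torch
    from diffusers import EulerDiscreteScheduler

    scheduler = EulerDiscreteScheduler()
    scheduler.set_timesteps(10)
    sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
    model_output = torch.randn(1, 4, 64, 64)
    prev_sample, pred_original_sample = scheduler.step(
        model_output, scheduler.timesteps[0], sample, return_dict=False
    )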

diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6

@@ -20,10 +20,13 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from .scheduling_utils import SchedulerMixin
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -71,7 +74,18 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         max_shift: Optional[float] = 1.15,
         base_image_seq_len: Optional[int] = 256,
         max_image_seq_len: Optional[int] = 4096,
+        invert_sigmas: bool = False,
+        shift_terminal: Optional[float] = None,
+        use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 
@@ -85,10 +99,19 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index = None
         self._begin_index = None
 
+        self._shift = shift
+
         self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.sigma_min = self.sigmas[-1].item()
         self.sigma_max = self.sigmas[0].item()
 
+    @property
+    def shift(self):
+        """
+        The value used for shifting.
+        """
+        return self._shift
+
     @property
     def step_index(self):
         """
@@ -114,6 +137,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
+    def set_shift(self, shift: float):
+        self._shift = shift
+
     def scale_noise(
         self,
         sample: torch.FloatTensor,
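
The new `shift` property and `set_shift` setter let callers retune the shift after construction; as shown in the `set_timesteps` hunk below, the schedule now reads `self.shift` instead of the frozen `config.shift`. A hedged usage sketch:

    from diffusers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0)
    scheduler.set_shift(3.0)  # subsequent set_timesteps calls use the new value
    scheduler.set_timesteps(num_inference_steps=8)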
@@ -168,6 +194,27 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
         return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
 
+    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
+        r"""
+        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
+        value.
+
+        Reference:
+        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
+
+        Args:
+            t (`torch.Tensor`):
+                A tensor of timesteps to be stretched and shifted.
+
+        Returns:
+            `torch.Tensor`:
+                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
+        """
+        one_minus_z = 1 - t
+        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
+        stretched_t = 1 - (one_minus_z / scale_factor)
+        return stretched_t
+
     def set_timesteps(
         self,
         num_inference_steps: int = None,
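
`stretch_shift_to_terminal` rescales 1 - t linearly so that the final value of the schedule lands exactly on `shift_terminal`. A numeric sketch, assuming shift_terminal = 0.1:

    import torch

    t = torch.tensor([1.0, 0.75, 0.5, 0.25])
    shift_terminal = 0.1
    one_minus_z = 1 - t
    scale_factor = one_minus_z[-1] / (1 - shift_terminal)
    stretched_t = 1 - (one_minus_z / scale_factor)
    print(stretched_t)  # tensor([1.0000, 0.7000, 0.4000, 0.1000])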
@@ -184,29 +231,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-
         if self.config.use_dynamic_shifting and mu is None:
             raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
 
         if sigmas is None:
-            self.num_inference_steps = num_inference_steps
             timesteps = np.linspace(
                 self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
             )
 
             sigmas = timesteps / self.config.num_train_timesteps
+        else:
+            sigmas = np.array(sigmas).astype(np.float32)
+            num_inference_steps = len(sigmas)
+        self.num_inference_steps = num_inference_steps
 
         if self.config.use_dynamic_shifting:
             sigmas = self.time_shift(mu, 1.0, sigmas)
         else:
-            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
+            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
+
+        if self.config.shift_terminal:
+            sigmas = self.stretch_shift_to_terminal(sigmas)
+
+        if self.config.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
 
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
         timesteps = sigmas * self.config.num_train_timesteps
 
-        self.timesteps = timesteps.to(device=device)
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.invert_sigmas:
+            sigmas = 1.0 - sigmas
+            timesteps = sigmas * self.config.num_train_timesteps
+            sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
+        else:
+            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
 
+        self.timesteps = timesteps.to(device=device)
+        self.sigmas = sigmas
         self._step_index = None
         self._begin_index = None
 
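
The reworked `set_timesteps` accepts an explicit sigma schedule and infers `num_inference_steps` from its length, and with `invert_sigmas=True` the schedule is flipped and padded with a trailing 1.0 instead of 0.0. A hedged sketch of the custom-sigmas path, using default config values:

    from diffusers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler()
    scheduler.set_timesteps(sigmas=[0.9, 0.6, 0.3, 0.1])
    print(scheduler.num_inference_steps)  # 4
    print(scheduler.sigmas)               # shifted sigmas plus a trailing 0.0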

@@ -307,5 +374,85 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def __len__(self):
         return self.config.num_train_timesteps