diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0

diffusers/schedulers/scheduling_flow_match_euler_discrete.py

@@ -20,10 +20,13 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from .scheduling_utils import SchedulerMixin
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -71,7 +74,18 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         max_shift: Optional[float] = 1.15,
         base_image_seq_len: Optional[int] = 256,
         max_image_seq_len: Optional[int] = 4096,
+        invert_sigmas: bool = False,
+        shift_terminal: Optional[float] = None,
+        use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 
@@ -85,10 +99,19 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index = None
         self._begin_index = None
 
+        self._shift = shift
+
         self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.sigma_min = self.sigmas[-1].item()
         self.sigma_max = self.sigmas[0].item()
 
+    @property
+    def shift(self):
+        """
+        The value used for shifting.
+        """
+        return self._shift
+
     @property
     def step_index(self):
         """
@@ -114,6 +137,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
+    def set_shift(self, shift: float):
+        self._shift = shift
+
     def scale_noise(
         self,
         sample: torch.FloatTensor,
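
With the new `_shift` attribute and `shift` property above, the static shift becomes adjustable at runtime instead of being frozen in the config. A minimal usage sketch (values are illustrative, not from the diff):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

# Illustrative values: construct with one shift, then retune it between runs.
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
print(scheduler.shift)    # 3.0, read back through the new property

scheduler.set_shift(5.0)  # e.g. a steeper shift for higher resolutions
scheduler.set_timesteps(num_inference_steps=28)
```
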
@@ -168,6 +194,27 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
         return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
 
+    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
+        r"""
+        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
+        value.
+
+        Reference:
+        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
+
+        Args:
+            t (`torch.Tensor`):
+                A tensor of timesteps to be stretched and shifted.
+
+        Returns:
+            `torch.Tensor`:
+                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
+        """
+        one_minus_z = 1 - t
+        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
+        stretched_t = 1 - (one_minus_z / scale_factor)
+        return stretched_t
+
     def set_timesteps(
         self,
         num_inference_steps: int = None,
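
In equation form, writing `shift_terminal` as $s$ and the last entry of the incoming schedule as $t_{\text{last}}$, the method computes

$$\hat{t} = 1 - \frac{(1 - t)(1 - s)}{1 - t_{\text{last}}},$$

so the final timestep satisfies $\hat{t}_{\text{last}} = s$ exactly.
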
@@ -184,29 +231,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-
         if self.config.use_dynamic_shifting and mu is None:
             raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
 
         if sigmas is None:
-            self.num_inference_steps = num_inference_steps
             timesteps = np.linspace(
                 self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
             )
 
             sigmas = timesteps / self.config.num_train_timesteps
+        else:
+            sigmas = np.array(sigmas).astype(np.float32)
+            num_inference_steps = len(sigmas)
+        self.num_inference_steps = num_inference_steps
 
         if self.config.use_dynamic_shifting:
             sigmas = self.time_shift(mu, 1.0, sigmas)
         else:
-            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
+            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
+
+        if self.config.shift_terminal:
+            sigmas = self.stretch_shift_to_terminal(sigmas)
+
+        if self.config.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
 
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
         timesteps = sigmas * self.config.num_train_timesteps
 
-        self.timesteps = timesteps.to(device=device)
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.invert_sigmas:
+            sigmas = 1.0 - sigmas
+            timesteps = sigmas * self.config.num_train_timesteps
+            sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
+        else:
+            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
 
+        self.timesteps = timesteps.to(device=device)
+        self.sigmas = sigmas
         self._step_index = None
         self._begin_index = None
 
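
Two behavioral consequences of this hunk, sketched below under default configs: an explicit `sigmas` list now drives the step count, and `invert_sigmas=True` flips the schedule and pads it with a terminal 1.0 instead of 0.0.

```python
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

# A custom sigma schedule: num_inference_steps is inferred from its length.
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(sigmas=np.linspace(1.0, 1 / 1000, 10).tolist())
print(scheduler.num_inference_steps)  # 10

# Inverted schedule: runs low-to-high, terminal sigma is 1.0 rather than 0.0.
inverted = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, invert_sigmas=True)
inverted.set_timesteps(num_inference_steps=10)
print(inverted.sigmas[-1])  # tensor(1.)
```
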
@@ -307,5 +374,85 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def __len__(self):
         return self.config.num_train_timesteps
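
The Karras conversion is the standard rho=7 interpolation between `sigma_max` and `sigma_min` in 1/rho-space; a standalone numpy sketch of the same arithmetic:

```python
import numpy as np

# Re-derivation of the Karras et al. (2022) rho=7 schedule, as in _convert_to_karras.
sigma_min, sigma_max, rho = 0.002, 1.0, 7.0
ramp = np.linspace(0, 1, 5)
min_inv_rho, max_inv_rho = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
print(sigmas)  # decreases monotonically from 1.0 to 0.002
```
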

diffusers/schedulers/scheduling_heun_discrete.py

@@ -329,10 +329,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -421,7 +421,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -445,7 +445,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.Tensor(
+        sigmas = np.array(
             [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [
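
This file, and the KDPM2, LMS, SA-Solver, and UniPC files below, all receive the same two-part fix: the exponential/beta converters are now given the local `num_inference_steps` argument rather than `self.num_inference_steps` (which can be stale or unset at that point in `set_timesteps`), and they return numpy arrays instead of torch tensors so the later numpy concatenation sees a consistent type. A hedged sketch that exercises the repaired path (requires `scipy` for beta sigmas):

```python
from diffusers import HeunDiscreteScheduler

# Exercises the beta-sigma path touched by this fix.
scheduler = HeunDiscreteScheduler(num_train_timesteps=1000, use_beta_sigmas=True)
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.sigmas.dtype)  # torch.float32, built from a numpy schedule
```
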

diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

@@ -289,10 +289,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
         elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
@@ -409,7 +409,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -433,7 +433,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.Tensor(
+        sigmas = np.array(
             [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [

diffusers/schedulers/scheduling_k_dpm_2_discrete.py

@@ -288,10 +288,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
         elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
@@ -422,7 +422,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -446,7 +446,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.Tensor(
+        sigmas = np.array(
             [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [

diffusers/schedulers/scheduling_lcm.py

@@ -643,16 +643,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t
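
With this change (mirrored in `TCDScheduler` below, since both copy `DDPMScheduler.previous_timestep`), any call made after `set_timesteps` looks the previous timestep up in `self.timesteps`, and the arithmetic fallback only applies when no schedule is set, where stepping back one training timestep is the right granularity. A small sketch of the looked-up behavior:

```python
from diffusers import LCMScheduler

scheduler = LCMScheduler()
scheduler.set_timesteps(num_inference_steps=4)

# num_inference_steps is set, so the previous timestep is taken from
# scheduler.timesteps instead of being computed arithmetically.
t = scheduler.timesteps[0]
print(scheduler.previous_timestep(t) == scheduler.timesteps[1])  # tensor(True)
```
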

diffusers/schedulers/scheduling_lms_discrete.py

@@ -302,10 +302,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_karras(in_sigmas=sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -399,7 +399,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -423,7 +423,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.Tensor(
+        sigmas = np.array(
             [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [

diffusers/schedulers/scheduling_repaint.py

@@ -319,7 +319,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
         prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
 
         # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
-        prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise
+        prev_known_part = (alpha_prod_t_prev**0.5) * original_image + (1 - alpha_prod_t_prev) * noise
 
         # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
         pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
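
For reference, Algorithm 1 Line 5 of the RePaint paper (arXiv:2201.09865) defines the known-region sample as

$$x_{t-1}^{\text{known}} \sim \mathcal{N}\!\left(\sqrt{\bar{\alpha}_t}\,x_0,\ (1 - \bar{\alpha}_t)\,I\right),$$

i.e. the `noise` term carries the coefficient this hunk adjusts.
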

diffusers/schedulers/scheduling_sasolver.py

@@ -167,6 +167,8 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
@@ -295,18 +297,28 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
             )
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)
         if self.config.use_karras_sigmas:
-            log_sigmas = np.log(sigmas)
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -387,8 +399,12 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t
 
         return alpha_t, sigma_t
 
@@ -437,7 +453,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -461,7 +477,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.Tensor(
+        sigmas = np.array(
             [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [
@@ -527,10 +543,13 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
                 x0_pred = model_output
             elif self.config.prediction_type == "v_prediction":
                 x0_pred = alpha_t * sample - sigma_t * model_output
+            elif self.config.prediction_type == "flow_prediction":
+                sigma_t = self.sigmas[self.step_index]
+                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
-                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                    " `v_prediction` for the SASolverScheduler."
+                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                    "`v_prediction`, or `flow_prediction` for the SASolverScheduler."
                )
 
        if self.config.thresholding:
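
Together these hunks let SA-Solver drive flow-matching models: `use_flow_sigmas=True` builds a shifted linear sigma schedule, `_sigma_to_alpha_sigma_t` switches to the flow parameterization (`alpha_t = 1 - sigma`), and `prediction_type="flow_prediction"` converts the velocity output into an x0 estimate. A configuration sketch (parameter values are illustrative, not from the diff):

```python
from diffusers import SASolverScheduler

# Illustrative config: enable the new flow-matching path.
scheduler = SASolverScheduler(
    num_train_timesteps=1000,
    use_flow_sigmas=True,
    flow_shift=3.0,
    prediction_type="flow_prediction",
)
scheduler.set_timesteps(num_inference_steps=20)
print(scheduler.timesteps[:4])
```
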

diffusers/schedulers/scheduling_tcd.py

@@ -680,16 +680,12 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t

diffusers/schedulers/scheduling_unipc_multistep.py

@@ -206,6 +206,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
@@ -347,11 +349,47 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            log_sigmas = np.log(sigmas)
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            log_sigmas = np.log(sigmas)
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             if self.config.final_sigmas_type == "sigma_min":
@@ -442,8 +480,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t
 
         return alpha_t, sigma_t
 
@@ -492,7 +534,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -516,7 +558,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas = torch.Tensor(
+        sigmas = np.array(
             [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [
@@ -572,10 +614,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 x0_pred = model_output
             elif self.config.prediction_type == "v_prediction":
                 x0_pred = alpha_t * sample - sigma_t * model_output
+            elif self.config.prediction_type == "flow_prediction":
+                sigma_t = self.sigmas[self.step_index]
+                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
-                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                    " `v_prediction` for the UniPCMultistepScheduler."
+                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                    "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
                )
 
        if self.config.thresholding:
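
UniPC gets the same flow-matching support as SA-Solver above. A hedged sketch of the new configuration (parameter values are illustrative):

```python
from diffusers import UniPCMultistepScheduler

# Illustrative sketch: a UniPC schedule driven by flow-matching sigmas.
scheduler = UniPCMultistepScheduler(
    num_train_timesteps=1000,
    use_flow_sigmas=True,
    flow_shift=5.0,
    prediction_type="flow_prediction",
)
scheduler.set_timesteps(num_inference_steps=30)
print(scheduler.sigmas[0], scheduler.sigmas[-1])  # shifted start, terminal 0.0
```
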

diffusers/training_utils.py

@@ -43,6 +43,9 @@ def set_seed(seed: int):
 
     Args:
         seed (`int`): The seed to set.
+
+    Returns:
+        `None`
     """
     random.seed(seed)
     np.random.seed(seed)
@@ -58,6 +61,17 @@ def compute_snr(noise_scheduler, timesteps):
     """
     Computes SNR as per
     https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+    for the given timesteps using the provided noise scheduler.
+
+    Args:
+        noise_scheduler (`NoiseScheduler`):
+            An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
+            the SNR values.
+        timesteps (`torch.Tensor`):
+            A tensor of timesteps for which the SNR is computed.
+
+    Returns:
+        `torch.Tensor`: A tensor containing the computed SNR values for each timestep.
     """
     alphas_cumprod = noise_scheduler.alphas_cumprod
     sqrt_alphas_cumprod = alphas_cumprod**0.5
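
The newly documented `compute_snr` is typically used for min-SNR loss weighting in training loops; a hedged sketch (the gamma value is illustrative):

```python
import torch
from diffusers import DDPMScheduler
from diffusers.training_utils import compute_snr

# Min-SNR-gamma weighting for an epsilon-prediction loss.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.randint(0, 1000, (8,))
snr = compute_snr(noise_scheduler, timesteps)
gamma = 5.0
mse_loss_weights = torch.stack([snr, gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
print(mse_loss_weights.shape)  # torch.Size([8])
```
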
@@ -284,7 +298,7 @@ def free_memory():
     elif torch.backends.mps.is_available():
         torch.mps.empty_cache()
     elif is_torch_npu_available():
-        torch_npu.empty_cache()
+        torch_npu.npu.empty_cache()
 
 
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
@@ -379,7 +393,7 @@ class EMAModel:
 
     @classmethod
     def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
-        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
+        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
         model = model_cls.from_pretrained(path)
 
         ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
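
For reference, a hedged usage sketch of the classmethod this hunk touches (the checkpoint path is a placeholder):

```python
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Placeholder path: any directory previously written by EMAModel.save_pretrained.
ema_unet = EMAModel.from_pretrained("path/to/ema_checkpoint", model_cls=UNet2DConditionModel)
```
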