diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +74 -28
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_flow_match_euler_discrete.py
CHANGED
@@ -20,10 +20,13 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from .scheduling_utils import SchedulerMixin
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -71,7 +74,18 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         max_shift: Optional[float] = 1.15,
         base_image_seq_len: Optional[int] = 256,
         max_image_seq_len: Optional[int] = 4096,
+        invert_sigmas: bool = False,
+        shift_terminal: Optional[float] = None,
+        use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 
@@ -85,10 +99,19 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index = None
         self._begin_index = None
 
+        self._shift = shift
+
         self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.sigma_min = self.sigmas[-1].item()
         self.sigma_max = self.sigmas[0].item()
 
+    @property
+    def shift(self):
+        """
+        The value used for shifting.
+        """
+        return self._shift
+
     @property
     def step_index(self):
         """
@@ -114,6 +137,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
+    def set_shift(self, shift: float):
+        self._shift = shift
+
     def scale_noise(
         self,
         sample: torch.FloatTensor,
@@ -168,6 +194,27 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
         return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
 
+    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
+        r"""
+        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
+        value.
+
+        Reference:
+        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
+
+        Args:
+            t (`torch.Tensor`):
+                A tensor of timesteps to be stretched and shifted.
+
+        Returns:
+            `torch.Tensor`:
+                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
+        """
+        one_minus_z = 1 - t
+        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
+        stretched_t = 1 - (one_minus_z / scale_factor)
+        return stretched_t
+
     def set_timesteps(
         self,
         num_inference_steps: int = None,
@@ -184,29 +231,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-
         if self.config.use_dynamic_shifting and mu is None:
             raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
 
         if sigmas is None:
-            self.num_inference_steps = num_inference_steps
             timesteps = np.linspace(
                 self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
             )
 
             sigmas = timesteps / self.config.num_train_timesteps
+        else:
+            sigmas = np.array(sigmas).astype(np.float32)
+            num_inference_steps = len(sigmas)
+        self.num_inference_steps = num_inference_steps
 
         if self.config.use_dynamic_shifting:
             sigmas = self.time_shift(mu, 1.0, sigmas)
         else:
-            sigmas = self.
+            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
+
+        if self.config.shift_terminal:
+            sigmas = self.stretch_shift_to_terminal(sigmas)
+
+        if self.config.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
 
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
         timesteps = sigmas * self.config.num_train_timesteps
 
-        self.
-
+        if self.config.invert_sigmas:
+            sigmas = 1.0 - sigmas
+            timesteps = sigmas * self.config.num_train_timesteps
+            sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
+        else:
+            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
 
+        self.timesteps = timesteps.to(device=device)
+        self.sigmas = sigmas
         self._step_index = None
         self._begin_index = None
 
@@ -307,5 +374,85 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def __len__(self):
         return self.config.num_train_timesteps
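The hunks above add opt-in sigma schedules (`use_karras_sigmas`, `use_exponential_sigmas`, `use_beta_sigmas`), terminal-shift stretching, sigma inversion, a mutable `shift`, and support for passing a custom sigma list to `set_timesteps`. A minimal usage sketch, with arbitrary values that are not taken from the diff:

```python
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

# Only one of the three sigma-schedule flags may be enabled at a time;
# use_beta_sigmas additionally requires scipy.
scheduler = FlowMatchEulerDiscreteScheduler(
    shift=3.0,
    use_karras_sigmas=True,
    shift_terminal=0.1,  # stretch the schedule so the final sigma lands at 0.1
)

print(scheduler.shift)    # new read-only property -> 3.0
scheduler.set_shift(5.0)  # new setter; takes effect on the next set_timesteps call

# set_timesteps still takes a step count, or a custom sigma list can now be passed instead.
scheduler.set_timesteps(num_inference_steps=30)
scheduler.set_timesteps(sigmas=np.linspace(1.0, 1 / 1000, 25).tolist())
```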
diffusers/schedulers/scheduling_heun_discrete.py
CHANGED
@@ -329,10 +329,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -421,7 +421,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -445,7 +445,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.array(
             [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [
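This fix, repeated for the KDPM2, LMS, and other schedulers below, passes the local `num_inference_steps` into the exponential/beta sigma converters. A hedged sketch of enabling one of these schedules on `HeunDiscreteScheduler`; the beta values are illustrative only:

```python
from diffusers import HeunDiscreteScheduler

# Exponential sigma spacing; use_beta_sigmas=True would require scipy instead.
scheduler = HeunDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    use_exponential_sigmas=True,
)
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.sigmas[:5])  # first few sigmas of the exponential schedule
```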
diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
CHANGED
@@ -289,10 +289,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
         elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
@@ -409,7 +409,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -433,7 +433,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.array(
            [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [
diffusers/schedulers/scheduling_k_dpm_2_discrete.py
CHANGED
@@ -288,10 +288,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
         elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
@@ -422,7 +422,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -446,7 +446,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.array(
             [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [
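The `_convert_to_*` helpers completed in these hunks all derive a new sigma spacing between `sigma_max` and `sigma_min`. A standalone NumPy restatement of the Karras and exponential formulas shown above (not the library code itself; the endpoint values are arbitrary):

```python
import math
import numpy as np

sigma_min, sigma_max, steps, rho = 0.03, 14.6, 10, 7.0

# Karras et al. (2022): linear ramp in sigma**(1/rho) space.
ramp = np.linspace(0, 1, steps)
karras = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

# Exponential: evenly spaced in log-sigma.
exponential = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), steps))

print(karras[0], karras[-1])            # runs from sigma_max down to sigma_min
print(exponential[0], exponential[-1])  # same endpoints, different spacing in between
```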
diffusers/schedulers/scheduling_lcm.py
CHANGED
@@ -643,16 +643,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t
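With this change, `previous_timestep` walks the actual `timesteps` tensor whenever an inference schedule is set (not only for custom timesteps), and otherwise simply returns `timestep - 1`; `TCDScheduler` below gets the identical fix. A small standalone sketch of the lookup, with an arbitrary schedule:

```python
import torch

timesteps = torch.tensor([999, 759, 499, 259, 19])  # arbitrary inference schedule

def previous_timestep(timestep):
    index = (timesteps == timestep).nonzero(as_tuple=True)[0][0]
    if index == timesteps.shape[0] - 1:
        return torch.tensor(-1)  # last step has no successor
    return timesteps[index + 1]

print(previous_timestep(499))  # tensor(259)
print(previous_timestep(19))   # tensor(-1)
```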
diffusers/schedulers/scheduling_lms_discrete.py
CHANGED
@@ -302,10 +302,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sigmas = self._convert_to_karras(in_sigmas=sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -399,7 +399,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -423,7 +423,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.array(
             [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [
diffusers/schedulers/scheduling_repaint.py
CHANGED
@@ -319,7 +319,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
         prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
 
         # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
-        prev_known_part = (alpha_prod_t_prev**0.5) * original_image + (
+        prev_known_part = (alpha_prod_t_prev**0.5) * original_image + (1 - alpha_prod_t_prev) * noise
 
         # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
         pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
diffusers/schedulers/scheduling_sasolver.py
CHANGED
@@ -167,6 +167,8 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
@@ -295,18 +297,28 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
            )
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)
         if self.config.use_karras_sigmas:
-            log_sigmas = np.log(sigmas)
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         elif self.config.use_exponential_sigmas:
-            sigmas =
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         elif self.config.use_beta_sigmas:
-            sigmas =
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -387,8 +399,12 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t
 
         return alpha_t, sigma_t
 
@@ -437,7 +453,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -461,7 +477,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.array(
             [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [
@@ -527,10 +543,13 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
                 x0_pred = model_output
             elif self.config.prediction_type == "v_prediction":
                 x0_pred = alpha_t * sample - sigma_t * model_output
+            elif self.config.prediction_type == "flow_prediction":
+                sigma_t = self.sigmas[self.step_index]
+                x0_pred = sample - sigma_t * model_output
             else:
                 raise ValueError(
-                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,
-                    "
+                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                    "`v_prediction`, or `flow_prediction` for the SASolverScheduler."
                 )
 
         if self.config.thresholding:
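The new `use_flow_sigmas`/`flow_shift` options and the `flow_prediction` branch let `SASolverScheduler` be configured for flow-matching models. A hedged configuration sketch; the shift value is illustrative and not a recommendation from the diff:

```python
from diffusers import SASolverScheduler

scheduler = SASolverScheduler(
    prediction_type="flow_prediction",  # handled by the new branch in convert_model_output
    use_flow_sigmas=True,
    flow_shift=3.0,
)
scheduler.set_timesteps(num_inference_steps=20)
print(scheduler.sigmas[:5])  # shifted flow sigmas, decreasing toward zero
```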
diffusers/schedulers/scheduling_tcd.py
CHANGED
@@ -680,16 +680,12 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t
diffusers/schedulers/scheduling_unipc_multistep.py
CHANGED
@@ -206,6 +206,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
@@ -347,11 +349,47 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         elif self.config.use_exponential_sigmas:
-
+            log_sigmas = np.log(sigmas)
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         elif self.config.use_beta_sigmas:
-
+            log_sigmas = np.log(sigmas)
+            sigmas = np.flip(sigmas).copy()
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             if self.config.final_sigmas_type == "sigma_min":
@@ -442,8 +480,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t
 
         return alpha_t, sigma_t
 
@@ -492,7 +534,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -516,7 +558,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
         sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
-        sigmas =
+        sigmas = np.array(
             [
                 sigma_min + (ppf * (sigma_max - sigma_min))
                 for ppf in [
@@ -572,10 +614,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 x0_pred = model_output
             elif self.config.prediction_type == "v_prediction":
                 x0_pred = alpha_t * sample - sigma_t * model_output
+            elif self.config.prediction_type == "flow_prediction":
+                sigma_t = self.sigmas[self.step_index]
+                x0_pred = sample - sigma_t * model_output
             else:
                 raise ValueError(
-                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,
-                    "
+                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+                    "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
                 )
 
         if self.config.thresholding:
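`UniPCMultistepScheduler` gains the same `use_flow_sigmas`/`flow_shift` options and `flow_prediction` handling as `SASolverScheduler` above. The flow-sigma schedule built in that branch can be restated in plain NumPy as follows (illustrative values, not library code):

```python
import numpy as np

num_train_timesteps, num_inference_steps, flow_shift = 1000, 8, 3.0

alphas = np.linspace(1, 1 / num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(flow_shift * sigmas / (1 + (flow_shift - 1) * sigmas))[:-1].copy()
timesteps = sigmas * num_train_timesteps

print(sigmas.round(3))     # decreasing shifted sigmas
print(timesteps.round(1))  # matching (float) timesteps
```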
diffusers/training_utils.py
CHANGED
@@ -43,6 +43,9 @@ def set_seed(seed: int):
 
     Args:
         seed (`int`): The seed to set.
+
+    Returns:
+        `None`
     """
     random.seed(seed)
     np.random.seed(seed)
@@ -58,6 +61,17 @@ def compute_snr(noise_scheduler, timesteps):
     """
     Computes SNR as per
     https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+    for the given timesteps using the provided noise scheduler.
+
+    Args:
+        noise_scheduler (`NoiseScheduler`):
+            An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
+            the SNR values.
+        timesteps (`torch.Tensor`):
+            A tensor of timesteps for which the SNR is computed.
+
+    Returns:
+        `torch.Tensor`: A tensor containing the computed SNR values for each timestep.
     """
     alphas_cumprod = noise_scheduler.alphas_cumprod
     sqrt_alphas_cumprod = alphas_cumprod**0.5
@@ -284,7 +298,7 @@ def free_memory():
     elif torch.backends.mps.is_available():
         torch.mps.empty_cache()
     elif is_torch_npu_available():
-        torch_npu.empty_cache()
+        torch_npu.npu.empty_cache()
 
 
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
@@ -379,7 +393,7 @@ class EMAModel:
 
     @classmethod
     def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
-        _, ema_kwargs = model_cls.
+        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
         model = model_cls.from_pretrained(path)
 
         ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
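The expanded `compute_snr` docstring above spells out its contract: given a scheduler exposing `alphas_cumprod` and a tensor of timesteps, it returns one SNR value per timestep. A hedged usage sketch; the scheduler choice and the Min-SNR weighting at the end are illustrative, not part of the diff:

```python
import torch
from diffusers import DDPMScheduler
from diffusers.training_utils import compute_snr

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.randint(0, 1000, (4,))

snr = compute_snr(noise_scheduler, timesteps)
print(snr.shape)  # torch.Size([4]) -- one SNR value per sampled timestep

# Example Min-SNR-style loss weight (gamma is a hyperparameter chosen here, not by the library).
gamma = 5.0
weights = torch.minimum(snr, torch.full_like(snr, gamma)) / snr
```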