diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2252 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +3 -14
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +293 -8
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1937 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +403 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +37 -15
- diffusers/utils/loading_utils.py +80 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,7 @@ import torch
|
|
22
22
|
|
23
23
|
from ..configuration_utils import ConfigMixin, register_to_config
|
24
24
|
from ..utils import deprecate, logging
|
25
|
+
from ..utils.torch_utils import randn_tensor
|
25
26
|
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
26
27
|
|
27
28
|
|
@@ -108,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
108
109
|
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
109
110
|
`algorithm_type="dpmsolver++"`.
|
110
111
|
algorithm_type (`str`, defaults to `dpmsolver++`):
|
111
|
-
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver`
|
112
|
-
algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
|
113
|
-
implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
|
114
|
-
recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
|
115
|
-
Stable Diffusion.
|
112
|
+
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver`
|
113
|
+
type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
|
114
|
+
`dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
|
115
|
+
paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
|
116
|
+
sampling like in Stable Diffusion.
|
116
117
|
solver_type (`str`, defaults to `midpoint`):
|
117
118
|
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
118
119
|
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
@@ -186,7 +187,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
186
187
|
self.init_noise_sigma = 1.0
|
187
188
|
|
188
189
|
# settings for DPM-Solver
|
189
|
-
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
|
190
|
+
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
|
190
191
|
if algorithm_type == "deis":
|
191
192
|
self.register_to_config(algorithm_type="dpmsolver++")
|
192
193
|
else:
|
@@ -197,7 +198,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
197
198
|
else:
|
198
199
|
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
199
200
|
|
200
|
-
if algorithm_type
|
201
|
+
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
|
201
202
|
raise ValueError(
|
202
203
|
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead."
|
203
204
|
)
|
@@ -493,10 +494,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
493
494
|
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
494
495
|
)
|
495
496
|
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
496
|
-
if self.config.algorithm_type
|
497
|
+
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
497
498
|
if self.config.prediction_type == "epsilon":
|
498
499
|
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
499
|
-
if self.config.variance_type in ["learned_range"]:
|
500
|
+
if self.config.variance_type in ["learned", "learned_range"]:
|
500
501
|
model_output = model_output[:, :3]
|
501
502
|
sigma = self.sigmas[self.step_index]
|
502
503
|
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
@@ -517,34 +518,43 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
517
518
|
x0_pred = self._threshold_sample(x0_pred)
|
518
519
|
|
519
520
|
return x0_pred
|
521
|
+
|
520
522
|
# DPM-Solver needs to solve an integral of the noise prediction model.
|
521
523
|
elif self.config.algorithm_type == "dpmsolver":
|
522
524
|
if self.config.prediction_type == "epsilon":
|
523
525
|
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
524
|
-
if self.config.variance_type in ["learned_range"]:
|
525
|
-
|
526
|
-
|
526
|
+
if self.config.variance_type in ["learned", "learned_range"]:
|
527
|
+
epsilon = model_output[:, :3]
|
528
|
+
else:
|
529
|
+
epsilon = model_output
|
527
530
|
elif self.config.prediction_type == "sample":
|
528
531
|
sigma = self.sigmas[self.step_index]
|
529
532
|
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
530
533
|
epsilon = (sample - alpha_t * model_output) / sigma_t
|
531
|
-
return epsilon
|
532
534
|
elif self.config.prediction_type == "v_prediction":
|
533
535
|
sigma = self.sigmas[self.step_index]
|
534
536
|
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
535
537
|
epsilon = alpha_t * model_output + sigma_t * sample
|
536
|
-
return epsilon
|
537
538
|
else:
|
538
539
|
raise ValueError(
|
539
540
|
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
540
541
|
" `v_prediction` for the DPMSolverSinglestepScheduler."
|
541
542
|
)
|
542
543
|
|
544
|
+
if self.config.thresholding:
|
545
|
+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
546
|
+
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
547
|
+
x0_pred = self._threshold_sample(x0_pred)
|
548
|
+
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
549
|
+
|
550
|
+
return epsilon
|
551
|
+
|
543
552
|
def dpm_solver_first_order_update(
|
544
553
|
self,
|
545
554
|
model_output: torch.Tensor,
|
546
555
|
*args,
|
547
556
|
sample: torch.Tensor = None,
|
557
|
+
noise: Optional[torch.Tensor] = None,
|
548
558
|
**kwargs,
|
549
559
|
) -> torch.Tensor:
|
550
560
|
"""
|
@@ -594,6 +604,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
594
604
|
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
595
605
|
elif self.config.algorithm_type == "dpmsolver":
|
596
606
|
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
607
|
+
elif self.config.algorithm_type == "sde-dpmsolver++":
|
608
|
+
assert noise is not None
|
609
|
+
x_t = (
|
610
|
+
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
611
|
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
|
612
|
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
613
|
+
)
|
597
614
|
return x_t
|
598
615
|
|
599
616
|
def singlestep_dpm_solver_second_order_update(
|
@@ -601,6 +618,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
601
618
|
model_output_list: List[torch.Tensor],
|
602
619
|
*args,
|
603
620
|
sample: torch.Tensor = None,
|
621
|
+
noise: Optional[torch.Tensor] = None,
|
604
622
|
**kwargs,
|
605
623
|
) -> torch.Tensor:
|
606
624
|
"""
|
@@ -688,6 +706,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
688
706
|
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
689
707
|
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
690
708
|
)
|
709
|
+
elif self.config.algorithm_type == "sde-dpmsolver++":
|
710
|
+
assert noise is not None
|
711
|
+
if self.config.solver_type == "midpoint":
|
712
|
+
x_t = (
|
713
|
+
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
|
714
|
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
715
|
+
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
|
716
|
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
717
|
+
)
|
718
|
+
elif self.config.solver_type == "heun":
|
719
|
+
x_t = (
|
720
|
+
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
|
721
|
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
722
|
+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
723
|
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
724
|
+
)
|
691
725
|
return x_t
|
692
726
|
|
693
727
|
def singlestep_dpm_solver_third_order_update(
|
@@ -800,6 +834,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
800
834
|
*args,
|
801
835
|
sample: torch.Tensor = None,
|
802
836
|
order: int = None,
|
837
|
+
noise: Optional[torch.Tensor] = None,
|
803
838
|
**kwargs,
|
804
839
|
) -> torch.Tensor:
|
805
840
|
"""
|
@@ -848,9 +883,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
848
883
|
)
|
849
884
|
|
850
885
|
if order == 1:
|
851
|
-
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
|
886
|
+
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise)
|
852
887
|
elif order == 2:
|
853
|
-
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
|
888
|
+
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
|
854
889
|
elif order == 3:
|
855
890
|
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
|
856
891
|
else:
|
@@ -892,8 +927,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
892
927
|
def step(
|
893
928
|
self,
|
894
929
|
model_output: torch.Tensor,
|
895
|
-
timestep: int,
|
930
|
+
timestep: Union[int, torch.Tensor],
|
896
931
|
sample: torch.Tensor,
|
932
|
+
generator=None,
|
897
933
|
return_dict: bool = True,
|
898
934
|
) -> Union[SchedulerOutput, Tuple]:
|
899
935
|
"""
|
@@ -929,6 +965,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
929
965
|
self.model_outputs[i] = self.model_outputs[i + 1]
|
930
966
|
self.model_outputs[-1] = model_output
|
931
967
|
|
968
|
+
if self.config.algorithm_type == "sde-dpmsolver++":
|
969
|
+
noise = randn_tensor(
|
970
|
+
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
971
|
+
)
|
972
|
+
else:
|
973
|
+
noise = None
|
974
|
+
|
932
975
|
order = self.order_list[self.step_index]
|
933
976
|
|
934
977
|
# For img2img denoising might start with order>1 which is not possible
|
@@ -940,9 +983,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
940
983
|
if order == 1:
|
941
984
|
self.sample = sample
|
942
985
|
|
943
|
-
prev_sample = self.singlestep_dpm_solver_update(
|
986
|
+
prev_sample = self.singlestep_dpm_solver_update(
|
987
|
+
self.model_outputs, sample=self.sample, order=order, noise=noise
|
988
|
+
)
|
944
989
|
|
945
|
-
# upon completion increase step index by one
|
990
|
+
# upon completion increase step index by one, noise=noise
|
946
991
|
self._step_index += 1
|
947
992
|
|
948
993
|
if not return_dict:
|
@@ -134,7 +134,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
134
134
|
|
135
135
|
self.timesteps = self.precondition_noise(sigmas)
|
136
136
|
|
137
|
-
self.sigmas =
|
137
|
+
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
138
138
|
|
139
139
|
# setable values
|
140
140
|
self.num_inference_steps = None
|
@@ -594,7 +594,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
594
594
|
def step(
|
595
595
|
self,
|
596
596
|
model_output: torch.Tensor,
|
597
|
-
timestep: int,
|
597
|
+
timestep: Union[int, torch.Tensor],
|
598
598
|
sample: torch.Tensor,
|
599
599
|
generator=None,
|
600
600
|
return_dict: bool = True,
|
@@ -12,15 +12,15 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
import math
|
15
16
|
from dataclasses import dataclass
|
16
|
-
from typing import Optional, Tuple, Union
|
17
|
+
from typing import List, Optional, Tuple, Union
|
17
18
|
|
18
19
|
import numpy as np
|
19
20
|
import torch
|
20
21
|
|
21
22
|
from ..configuration_utils import ConfigMixin, register_to_config
|
22
23
|
from ..utils import BaseOutput, logging
|
23
|
-
from ..utils.torch_utils import randn_tensor
|
24
24
|
from .scheduling_utils import SchedulerMixin
|
25
25
|
|
26
26
|
|
@@ -66,12 +66,19 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
66
66
|
self,
|
67
67
|
num_train_timesteps: int = 1000,
|
68
68
|
shift: float = 1.0,
|
69
|
+
use_dynamic_shifting=False,
|
70
|
+
base_shift: Optional[float] = 0.5,
|
71
|
+
max_shift: Optional[float] = 1.15,
|
72
|
+
base_image_seq_len: Optional[int] = 256,
|
73
|
+
max_image_seq_len: Optional[int] = 4096,
|
69
74
|
):
|
70
75
|
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
71
76
|
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
72
77
|
|
73
78
|
sigmas = timesteps / num_train_timesteps
|
74
|
-
|
79
|
+
if not use_dynamic_shifting:
|
80
|
+
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
81
|
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
75
82
|
|
76
83
|
self.timesteps = sigmas * num_train_timesteps
|
77
84
|
|
@@ -114,7 +121,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
114
121
|
noise: Optional[torch.FloatTensor] = None,
|
115
122
|
) -> torch.FloatTensor:
|
116
123
|
"""
|
117
|
-
|
124
|
+
Forward process in flow-matching
|
118
125
|
|
119
126
|
Args:
|
120
127
|
sample (`torch.FloatTensor`):
|
@@ -126,10 +133,31 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
126
133
|
`torch.FloatTensor`:
|
127
134
|
A scaled input sample.
|
128
135
|
"""
|
129
|
-
|
130
|
-
|
136
|
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
137
|
+
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
|
138
|
+
|
139
|
+
if sample.device.type == "mps" and torch.is_floating_point(timestep):
|
140
|
+
# mps does not support float64
|
141
|
+
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
|
142
|
+
timestep = timestep.to(sample.device, dtype=torch.float32)
|
143
|
+
else:
|
144
|
+
schedule_timesteps = self.timesteps.to(sample.device)
|
145
|
+
timestep = timestep.to(sample.device)
|
146
|
+
|
147
|
+
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
148
|
+
if self.begin_index is None:
|
149
|
+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
|
150
|
+
elif self.step_index is not None:
|
151
|
+
# add_noise is called after first denoising step (for inpainting)
|
152
|
+
step_indices = [self.step_index] * timestep.shape[0]
|
153
|
+
else:
|
154
|
+
# add noise is called before first denoising step to create initial latent(img2img)
|
155
|
+
step_indices = [self.begin_index] * timestep.shape[0]
|
156
|
+
|
157
|
+
sigma = sigmas[step_indices].flatten()
|
158
|
+
while len(sigma.shape) < len(sample.shape):
|
159
|
+
sigma = sigma.unsqueeze(-1)
|
131
160
|
|
132
|
-
sigma = self.sigmas[self.step_index]
|
133
161
|
sample = sigma * noise + (1.0 - sigma) * sample
|
134
162
|
|
135
163
|
return sample
|
@@ -137,7 +165,16 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
137
165
|
def _sigma_to_t(self, sigma):
|
138
166
|
return sigma * self.config.num_train_timesteps
|
139
167
|
|
140
|
-
def
|
168
|
+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
169
|
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
170
|
+
|
171
|
+
def set_timesteps(
|
172
|
+
self,
|
173
|
+
num_inference_steps: int = None,
|
174
|
+
device: Union[str, torch.device] = None,
|
175
|
+
sigmas: Optional[List[float]] = None,
|
176
|
+
mu: Optional[float] = None,
|
177
|
+
):
|
141
178
|
"""
|
142
179
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
143
180
|
|
@@ -147,17 +184,26 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
147
184
|
device (`str` or `torch.device`, *optional*):
|
148
185
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
149
186
|
"""
|
150
|
-
self.num_inference_steps = num_inference_steps
|
151
187
|
|
152
|
-
|
153
|
-
|
154
|
-
)
|
188
|
+
if self.config.use_dynamic_shifting and mu is None:
|
189
|
+
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
155
190
|
|
156
|
-
sigmas
|
157
|
-
|
158
|
-
|
191
|
+
if sigmas is None:
|
192
|
+
self.num_inference_steps = num_inference_steps
|
193
|
+
timesteps = np.linspace(
|
194
|
+
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
195
|
+
)
|
196
|
+
|
197
|
+
sigmas = timesteps / self.config.num_train_timesteps
|
198
|
+
|
199
|
+
if self.config.use_dynamic_shifting:
|
200
|
+
sigmas = self.time_shift(mu, 1.0, sigmas)
|
201
|
+
else:
|
202
|
+
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
159
203
|
|
204
|
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
160
205
|
timesteps = sigmas * self.config.num_train_timesteps
|
206
|
+
|
161
207
|
self.timesteps = timesteps.to(device=device)
|
162
208
|
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
163
209
|
|
@@ -246,32 +292,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
246
292
|
sample = sample.to(torch.float32)
|
247
293
|
|
248
294
|
sigma = self.sigmas[self.step_index]
|
295
|
+
sigma_next = self.sigmas[self.step_index + 1]
|
249
296
|
|
250
|
-
|
251
|
-
|
252
|
-
noise = randn_tensor(
|
253
|
-
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
|
254
|
-
)
|
255
|
-
|
256
|
-
eps = noise * s_noise
|
257
|
-
sigma_hat = sigma * (gamma + 1)
|
258
|
-
|
259
|
-
if gamma > 0:
|
260
|
-
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
261
|
-
|
262
|
-
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
263
|
-
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
|
264
|
-
# backwards compatibility
|
265
|
-
|
266
|
-
# if self.config.prediction_type == "vector_field":
|
267
|
-
|
268
|
-
denoised = sample - model_output * sigma
|
269
|
-
# 2. Convert to an ODE derivative
|
270
|
-
derivative = (sample - denoised) / sigma_hat
|
271
|
-
|
272
|
-
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
297
|
+
prev_sample = sample + (sigma_next - sigma) * model_output
|
273
298
|
|
274
|
-
prev_sample = sample + derivative * dt
|
275
299
|
# Cast sample back to model compatible dtype
|
276
300
|
prev_sample = prev_sample.to(model_output.dtype)
|
277
301
|
|